fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -1
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +16 -6
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +3 -0
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
  56. fusion_bench/method/simple_average.py +16 -4
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +6 -6
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/misc.py +48 -2
  122. fusion_bench/utils/modelscope.py +265 -0
  123. fusion_bench/utils/parameters.py +2 -2
  124. fusion_bench/utils/rich_utils.py +3 -0
  125. fusion_bench/utils/state_dict_arithmetic.py +34 -27
  126. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
  127. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
  128. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  129. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  130. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  131. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  132. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  133. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  134. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  135. fusion_bench_config/hydra/default.yaml +6 -2
  136. fusion_bench_config/llama_full_finetune.yaml +1 -0
  137. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  138. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  139. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  140. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
  141. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
  142. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  143. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  144. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
  171. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
  172. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  173. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  174. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  175. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  178. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  179. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  180. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  181. fusion_bench_config/nyuv2_config.yaml +3 -1
  182. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  183. fusion_bench_config/path/default.yaml +28 -0
  184. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  185. fusion_bench_config/method/adamerging.yaml +0 -23
  186. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  187. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  188. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  189. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  190. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  191. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  192. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  193. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/method/base_algorithm.py
@@ -1,8 +1,44 @@
+"""
+Base algorithm classes for model fusion.
+
+This module provides the foundational abstract base class for implementing model fusion
+algorithms in the FusionBench framework. It defines the standard interface and lifecycle
+hooks that all fusion algorithms should follow.
+
+The main class `BaseAlgorithm` serves as a template for creating various model fusion
+strategies such as simple averaging, task arithmetic, weight interpolation, and more
+advanced techniques. It integrates with the YAML configuration system and provides
+hooks for setup and cleanup operations.
+
+Classes:
+    BaseAlgorithm: Abstract base class for all model fusion algorithms.
+    BaseModelFusionAlgorithm: Alias for BaseAlgorithm (backward compatibility).
+
+Example:
+    Implementing a custom fusion algorithm:
+
+    >>> from fusion_bench.method.base_algorithm import BaseAlgorithm
+    >>> from fusion_bench.modelpool import BaseModelPool
+    >>>
+    >>> class WeightedAverageAlgorithm(BaseAlgorithm):
+    ...     def __init__(self, weights=None, **kwargs):
+    ...         self.register_parameter_to_config("weights", "weights", weights or [])
+    ...         super().__init__(**kwargs)
+    ...
+    ...     def run(self, modelpool: BaseModelPool):
+    ...         models = list(modelpool)
+    ...         if len(self.weights) != len(models):
+    ...             raise ValueError("Number of weights must match number of models")
+    ...
+    ...         # Implement weighted averaging logic here
+    ...         return fused_model
+"""
+
 import logging
 from abc import abstractmethod
 from typing import Optional  # noqa: F401
 
-from fusion_bench.mixins import BaseYAMLSerializableModel
+from fusion_bench.mixins import BaseYAMLSerializable
 from fusion_bench.modelpool import BaseModelPool
 
 __all__ = ["BaseAlgorithm", "BaseModelFusionAlgorithm"]
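
Note the import change above: BaseYAMLSerializableModel is renamed to BaseYAMLSerializable. Downstream code that imports the mixin by its old name will break on upgrade. A minimal compatibility shim, assuming the old name stops being exported somewhere between 0.2.19 and 0.2.21, might look like:

    # Hedged import shim: prefer the new mixin name, fall back to the old one.
    try:
        from fusion_bench.mixins import BaseYAMLSerializable
    except ImportError:  # older fusion-bench releases
        from fusion_bench.mixins import (
            BaseYAMLSerializableModel as BaseYAMLSerializable,
        )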
@@ -10,36 +46,183 @@ __all__ = ["BaseAlgorithm", "BaseModelFusionAlgorithm"]
 
 log = logging.getLogger(__name__)
 
 
-class BaseAlgorithm(BaseYAMLSerializableModel):
+class BaseAlgorithm(BaseYAMLSerializable):
     """
     Base class for model fusion algorithms.
 
-    This class provides a template for implementing model fusion algorithms.
-    Subclasses must implement the `run` method to define the fusion logic.
+    This abstract class provides a standardized interface for implementing model fusion
+    algorithms. It inherits from BaseYAMLSerializable to support configuration loading
+    from YAML files.
+
+    The class follows a template method pattern where subclasses must implement the
+    core fusion logic in the `run` method, while optional lifecycle hooks allow for
+    setup and cleanup operations.
+
+    Attributes:
+        _program: Optional program reference for algorithm execution context.
+        _config_key (str): Configuration key used for YAML serialization, defaults to "method".
+
+    Examples:
+        Creating a simple averaging algorithm:
+
+        >>> class SimpleAverageAlgorithm(BaseAlgorithm):
+        ...     def run(self, modelpool: BaseModelPool):
+        ...         # Implementation of model averaging logic
+        ...         return averaged_model
+        ...
+        >>> algorithm = SimpleAverageAlgorithm()
+        >>> merged_model = algorithm.run(modelpool)
+
+        Loading algorithm from YAML configuration:
+
+        >>> algorithm = BaseAlgorithm.from_yaml("config.yaml")
+        >>> result = algorithm.run(modelpool)
+
+    Note:
+        Subclasses must implement the abstract `run` method to define the specific
+        fusion strategy (e.g., simple averaging, task arithmetic, etc.).
     """
 
     _program = None
     _config_key = "method"
 
+    def on_run_start(self):
+        """
+        Lifecycle hook called at the beginning of algorithm execution.
+
+        This method is invoked before the main `run` method executes, providing
+        an opportunity for subclasses to perform initialization tasks such as:
+
+        - Setting up logging or monitoring
+        - Initializing algorithm-specific state
+        - Validating prerequisites
+        - Preparing computational resources
+
+        The default implementation does nothing, allowing subclasses to override
+        as needed for their specific requirements.
+
+        Examples:
+            >>> class MyAlgorithm(BaseAlgorithm):
+            ...     def on_run_start(self):
+            ...         super().on_run_start()
+            ...         print("Starting model fusion...")
+            ...         self.start_time = time.time()
+        """
+        pass
+
+    def on_run_end(self):
+        """
+        Lifecycle hook called at the end of algorithm execution.
+
+        This method is invoked after the main `run` method completes, providing
+        an opportunity for subclasses to perform cleanup and finalization tasks such as:
+
+        - Logging execution statistics or results
+        - Cleaning up temporary resources
+        - Saving intermediate results or metrics
+        - Releasing computational resources
+
+        The method is called regardless of whether the `run` method succeeded or failed,
+        making it suitable for cleanup operations that should always occur.
+
+        The default implementation does nothing, allowing subclasses to override
+        as needed for their specific requirements.
+
+        Examples:
+            >>> class MyAlgorithm(BaseAlgorithm):
+            ...     def on_run_end(self):
+            ...         super().on_run_end()
+            ...         elapsed = time.time() - self.start_time
+            ...         print(f"Fusion completed in {elapsed:.2f}s")
+        """
+        pass
+
     @abstractmethod
     def run(self, modelpool: BaseModelPool):
         """
-        Fuse the models in the given model pool.
+        Execute the model fusion algorithm on the provided model pool.
+
+        This is the core method that must be implemented by all subclasses to define
+        their specific fusion strategy. The method takes a pool of models and produces
+        a fused result according to the algorithm's logic.
+
+        Args:
+            modelpool (BaseModelPool): A collection of models to be fused. The modelpool
+                provides access to individual models and their metadata, allowing the
+                algorithm to iterate over models, access their parameters, and perform
+                fusion operations.
+
+        Returns:
+            The type of return value depends on the specific algorithm implementation.
+            Common return types include:
+
+            - A single fused model (torch.nn.Module)
+            - A dictionary of fused models for multi-task scenarios
+            - Fusion results with additional metadata
+            - Custom data structures specific to the algorithm
 
-        This method must be implemented by subclasses to define the fusion logic.
+        Raises:
+            NotImplementedError: If called on the base class without implementation.
+            ValueError: If the modelpool is invalid or incompatible with the algorithm.
+            RuntimeError: If fusion fails due to model incompatibilities or other issues.
 
         Examples:
-            >>> algorithm = SimpleAverageAlgorithm()
-            >>> modelpool = ModelPool()
-            >>> merged_model = algorithm.run(modelpool)
+            Simple averaging implementation:
 
-        Args:
-            modelpool (BaseModelPool): The pool of models to fuse.
+            >>> def run(self, modelpool: BaseModelPool):
+            ...     models = [model for model in modelpool]
+            ...     averaged_params = {}
+            ...     for name in models[0].state_dict():
+            ...         averaged_params[name] = torch.stack([
+            ...             model.state_dict()[name] for model in models
+            ...         ]).mean(dim=0)
+            ...     fused_model = copy.deepcopy(models[0])
+            ...     fused_model.load_state_dict(averaged_params)
+            ...     return fused_model
+
+            Task arithmetic implementation:
+
+            >>> def run(self, modelpool: BaseModelPool):
+            ...     pretrained = modelpool.get_model('pretrained')
+            ...     task_vectors = []
+            ...     for model_name in modelpool.model_names:
+            ...         if model_name != 'pretrained':
+            ...             task_vector = self.compute_task_vector(
+            ...                 modelpool.get_model(model_name), pretrained
+            ...             )
+            ...             task_vectors.append(task_vector)
+            ...     return self.merge_task_vectors(pretrained, task_vectors)
+
+        Note:
+            - The modelpool iteration order may affect results for non-commutative operations
+            - Ensure model compatibility (architecture, parameter shapes) before fusion
+            - Consider memory constraints when loading multiple large models
+            - Use appropriate device placement for GPU/CPU computation
         """
         pass
 
 
 BaseModelFusionAlgorithm = BaseAlgorithm
 """
-Alias for `BaseAlgorithm`.
+Alias for BaseAlgorithm class.
+
+This alias is provided for backward compatibility and semantic clarity.
+Some users may prefer the more explicit name 'BaseModelFusionAlgorithm'
+to emphasize that this class is specifically designed for model fusion
+tasks, while others may prefer the shorter 'BaseAlgorithm' name.
+
+Both names refer to the exact same class and can be used interchangeably.
+
+Examples:
+    Using the original name:
+    >>> class MyAlgorithm(BaseAlgorithm):
+    ...     def run(self, modelpool): pass
+
+    Using the alias:
+    >>> class MyAlgorithm(BaseModelFusionAlgorithm):
+    ...     def run(self, modelpool): pass
+
+Note:
+    The alias is maintained for compatibility but BaseAlgorithm is the
+    preferred name for new implementations.
 """
fusion_bench/method/bitdelta/__init__.py
@@ -0,0 +1,4 @@
+"""
+Adapted from https://github.com/FasterDecoding/BitDelta
+"""
+from .bitdelta import BitDeltaAlgorithm
fusion_bench/method/bitdelta/bitdelta.py
@@ -0,0 +1,156 @@
+import logging
+import os
+
+import torch
+import torch.nn.functional as F
+from tqdm.auto import tqdm
+
+from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
+from fusion_bench.modelpool import CausalLMPool
+
+from .bitdelta_utils.data import get_dataloader, get_dataset
+from .bitdelta_utils.diff import compress_diff, save_diff, save_full_model
+
+log = logging.getLogger(__name__)
+
+
+class BitDeltaAlgorithm(
+    BaseAlgorithm,
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+):
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "save_dir": "save_dir",
+        "save_full_model": "save_full_model",
+        "lr": "lr",
+        "batch_size": "batch_size",
+        "num_steps": "num_steps",
+        "dataset_name": "dataset_name",
+        "subset": "subset",
+        "split": "split",
+        "max_length": "max_length",
+    }
+
+    def __init__(
+        self,
+        save_dir: str,
+        save_full_model: bool = False,
+        lr: float = 1e-4,
+        batch_size: int = 4,
+        num_steps: int = 100,
+        dataset_name: str = "c4",
+        subset: str = "en",
+        split: str = "train",
+        max_length: int = 128,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.save_dir = save_dir
+        self.save_full_model = save_full_model
+        self.lr = lr
+        self.batch_size = batch_size
+        self.num_steps = num_steps
+        self.dataset_name = dataset_name
+        self.subset = subset
+        self.split = split
+        self.max_length = max_length
+
+    def run(self, modelpool: CausalLMPool):
+        if self.save_dir is None:
+            log.info(
+                f"save_dir not set, using log_dir instead. log_dir: {self.log_dir}"
+            )
+            self.save_dir = self.log_dir
+
+        with self.profile("model loading"):
+            tokenizer = modelpool.load_tokenizer()
+            base_model = modelpool.load_pretrained_model()
+            finetuned_model = modelpool.load_model(modelpool.model_names[0])
+            finetuned_compressed_model = modelpool.load_model(modelpool.model_names[0])
+
+        with self.profile("model compression"):
+            print(f"compressing diff...")
+            compress_diff(base_model, finetuned_model, finetuned_compressed_model)
+
+        # save untrained delta
+        save_diff(
+            finetuned_compressed_model, os.path.join(self.save_dir, "diff_untrained.pt")
+        )
+
+        optimizer = torch.optim.AdamW(
+            finetuned_compressed_model.parameters(), lr=self.lr
+        )
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, self.num_steps
+        )
+
+        train_num_samples = self.batch_size * self.num_steps
+        train_dataset = get_dataset(
+            self.dataset_name,
+            self.subset,
+            "train",
+            size=train_num_samples,
+        )
+        train_dataloader = get_dataloader(
+            train_dataset,
+            tokenizer,
+            self.batch_size,
+            num_workers=4,
+            max_length=self.max_length,
+        )
+
+        bar = tqdm(train_dataloader)
+
+        train_loss_list = []
+
+        # Train loop
+        for step, batch in enumerate(bar):
+            batch1 = {k: v.to(finetuned_model.device) for k, v in batch.items()}
+            with torch.inference_mode():
+                finetuned_outputs = finetuned_model(**batch1)
+
+            batch2 = {
+                k: v.to(finetuned_compressed_model.device) for k, v in batch.items()
+            }
+            finetuned_compressed_outputs = finetuned_compressed_model(**batch2)
+
+            loss = F.mse_loss(
+                finetuned_outputs.logits.clone().to(
+                    finetuned_compressed_outputs.logits.device
+                ),
+                finetuned_compressed_outputs.logits,
+            )
+
+            train_loss_list.append(loss.item())
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            scheduler.step()
+
+            bar.set_description(f"train loss: {loss.item()}")
+
+        # save trained delta
+        save_diff(finetuned_compressed_model, os.path.join(self.save_dir, "diff.pt"))
+
+        if self.save_full_model:
+            print("saving uncalibrated model")
+            save_full_model(
+                base_model,
+                tokenizer,
+                os.path.join(self.save_dir, "diff_untrained.pt"),
+                os.path.join(self.save_dir, "uncalibrated_model"),
+                device="cpu",
+            )
+            print("saving calibrated model")
+            save_full_model(
+                base_model,
+                tokenizer,
+                os.path.join(self.save_dir, "diff.pt"),
+                os.path.join(self.save_dir, "calibrated_model"),
+                device="cpu",
+            )
+
+        del base_model, finetuned_model, finetuned_compressed_model
+        torch.cuda.empty_cache()
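
For orientation, a hedged driver sketch for the new BitDeltaAlgorithm: the keyword arguments mirror the constructor defaults above (and the new fusion_bench_config/method/bitdelta/bitdelta.yaml), while pool is a hypothetical, pre-built CausalLMPool; its construction is omitted here.

    from fusion_bench.method.bitdelta import BitDeltaAlgorithm

    # `pool` (construction omitted) must expose a pretrained base model and at
    # least one fine-tuned model: `run` compresses modelpool.model_names[0]
    # against load_pretrained_model(), then trains the compressed model to
    # match the fine-tuned model's logits on C4 samples.
    algorithm = BitDeltaAlgorithm(
        save_dir="outputs/bitdelta",  # diff_untrained.pt and diff.pt land here
        save_full_model=True,  # also materialize calibrated/uncalibrated models
        lr=1e-4,
        batch_size=4,
        num_steps=100,
        dataset_name="c4",
        subset="en",
        split="train",
        max_length=128,
    )
    algorithm.run(pool)  # side effects only: writes deltas (and full models) to save_dir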