fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (188)
  1. fusion_bench/__init__.py +22 -2
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +6 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/constants/runtime.py +57 -0
  9. fusion_bench/dataset/clip_dataset.py +2 -1
  10. fusion_bench/dataset/gpt2_glue.py +9 -9
  11. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  12. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  13. fusion_bench/dataset/image_dataset.py +1 -1
  14. fusion_bench/dataset/nyuv2.py +2 -2
  15. fusion_bench/method/__init__.py +24 -5
  16. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  17. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  18. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  19. fusion_bench/method/base_algorithm.py +195 -12
  20. fusion_bench/method/bitdelta/__init__.py +5 -0
  21. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  25. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  26. fusion_bench/method/classification/clip_finetune.py +1 -1
  27. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  28. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  29. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  30. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  31. fusion_bench/method/ensemble.py +12 -12
  32. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  33. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
  34. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  35. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  36. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  37. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  38. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  39. fusion_bench/method/linear/expo.py +2 -1
  40. fusion_bench/method/linear/linear_interpolation.py +6 -4
  41. fusion_bench/method/linear/simple_average_for_llama.py +17 -13
  42. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  43. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  44. fusion_bench/method/model_recombination.py +2 -5
  45. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  46. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  47. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  48. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  49. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  50. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  51. fusion_bench/method/randes/modelsoup.py +1 -3
  52. fusion_bench/method/regmean/clip_regmean.py +2 -2
  53. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  54. fusion_bench/method/regmean/regmean.py +2 -11
  55. fusion_bench/method/regmean_plusplus/__init__.py +1 -1
  56. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
  57. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
  58. fusion_bench/method/simple_average.py +12 -16
  59. fusion_bench/method/slerp/slerp.py +5 -2
  60. fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
  61. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  62. fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
  63. fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
  64. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
  65. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  66. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  67. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  68. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  69. fusion_bench/method/we_moe/__init__.py +1 -0
  70. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  71. fusion_bench/method/we_moe/entropy_loss.py +25 -0
  72. fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
  73. fusion_bench/method/we_moe/utils.py +15 -0
  74. fusion_bench/method/we_moe/we_moe.py +6 -6
  75. fusion_bench/method/weighted_average/llama.py +4 -16
  76. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  77. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  78. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  79. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  80. fusion_bench/mixins/__init__.py +10 -2
  81. fusion_bench/mixins/clip_classification.py +15 -45
  82. fusion_bench/mixins/hydra_config.py +105 -7
  83. fusion_bench/mixins/lightning_fabric.py +2 -0
  84. fusion_bench/mixins/serialization.py +275 -48
  85. fusion_bench/modelpool/__init__.py +2 -2
  86. fusion_bench/modelpool/base_pool.py +29 -9
  87. fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
  88. fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
  89. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  90. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  91. fusion_bench/models/__init__.py +7 -1
  92. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  93. fusion_bench/models/hf_utils.py +160 -0
  94. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  95. fusion_bench/models/linearized/vision_model.py +1 -1
  96. fusion_bench/models/model_card_templates/default.md +46 -0
  97. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  98. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  99. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  100. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  101. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  102. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  103. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  104. fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
  105. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  106. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
  107. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  108. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  109. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  110. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
  111. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  112. fusion_bench/models/parameter_dict.py +1 -1
  113. fusion_bench/models/sparse_we_moe.py +1 -53
  114. fusion_bench/models/utils.py +26 -0
  115. fusion_bench/models/we_moe.py +1 -53
  116. fusion_bench/models/wrappers/ensemble.py +6 -4
  117. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  118. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  119. fusion_bench/programs/base_program.py +81 -2
  120. fusion_bench/programs/fabric_fusion_program.py +46 -61
  121. fusion_bench/scripts/cli.py +38 -5
  122. fusion_bench/taskpool/base_pool.py +4 -3
  123. fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
  124. fusion_bench/taskpool/dummy.py +1 -1
  125. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  126. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  127. fusion_bench/utils/__init__.py +7 -1
  128. fusion_bench/utils/cache_utils.py +101 -1
  129. fusion_bench/utils/devices.py +14 -4
  130. fusion_bench/utils/fabric.py +2 -2
  131. fusion_bench/utils/instantiate_utils.py +3 -1
  132. fusion_bench/utils/lazy_imports.py +23 -0
  133. fusion_bench/utils/lazy_state_dict.py +38 -3
  134. fusion_bench/utils/modelscope.py +127 -8
  135. fusion_bench/utils/parameters.py +2 -2
  136. fusion_bench/utils/path.py +56 -0
  137. fusion_bench/utils/pylogger.py +1 -1
  138. fusion_bench/utils/rich_utils.py +3 -0
  139. fusion_bench/utils/state_dict_arithmetic.py +25 -23
  140. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
  141. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
  142. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  143. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  144. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  145. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  146. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  147. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  148. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  149. fusion_bench_config/hydra/default.yaml +6 -2
  150. fusion_bench_config/llama_full_finetune.yaml +1 -0
  151. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  152. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  153. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  154. fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
  155. fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
  156. fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
  157. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  158. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  159. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
  160. fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
  167. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
  168. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  169. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  170. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  171. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  172. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  173. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  174. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  175. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  176. fusion_bench_config/nyuv2_config.yaml +3 -1
  177. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  178. fusion_bench_config/path/default.yaml +28 -0
  179. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  180. fusion_bench_config/method/adamerging.yaml +0 -23
  181. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  182. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  183. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  184. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
  185. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
  186. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
  187. {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
  188. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
@@ -22,6 +22,7 @@ import torch
  from torch import Tensor, nn
  from torch.func import functional_call

+ from fusion_bench.models.utils import StateDictType, del_attr, get_attr, set_attr
  from fusion_bench.utils.type import StateDictType, TorchModelType

  log = logging.getLogger(__name__)
@@ -29,77 +30,7 @@ log = logging.getLogger(__name__)
  __all__ = ["get_task_wise_weights", "fuse_weights", "TaskWiseMergedModel"]


- def del_attr(obj, names: List[str]):
-     """
-     Deletes an attribute from an object recursively.
-
-     Args:
-         obj (object): Object to delete attribute from.
-         names (list): List of attribute names to delete recursively.
-     """
-     if len(names) == 1:
-         delattr(obj, names[0])
-     else:
-         del_attr(getattr(obj, names[0]), names[1:])
-
-
- def set_attr(obj, names: List[str], val):
-     """
-     Sets an attribute of an object recursively.
-
-     Args:
-         obj (object): Object to set attribute of.
-         names (list): List of attribute names to set recursively.
-         val (object): Value to set the attribute to.
-     """
-     if len(names) == 1:
-         setattr(obj, names[0], val)
-     else:
-         set_attr(getattr(obj, names[0]), names[1:], val)
-
-
- def get_attr(obj, names: List[str]):
-     """
-     Gets an attribute of an object recursively.
-
-     Args:
-         obj (object): Object to get attribute of.
-         names (list): List of attribute names to get recursively.
-
-     Returns:
-         object: The attribute of the object.
-     """
-     if len(names) == 1:
-         return getattr(obj, names[0])
-     else:
-         return get_attr(getattr(obj, names[0]), names[1:])
-
-
- def check_parameterNamesMatch(checkpoints: List[StateDictType]) -> None:
-     """
-     Checks that the parameter names of the given checkpoints match.
-
-     Args:
-         checkpoints (List[Dict[str, float]]): A list of checkpoints, where each checkpoint is a dictionary of parameter names and their corresponding values.
-
-     Raises:
-         ValueError: If the number of checkpoints is less than 2 or if the parameter names of any two checkpoints differ.
-
-     """
-     parameter_names = set(checkpoints[0].keys())
-
-     if len(checkpoints) >= 2:
-         # raise ValueError("Number of models is less than 2.")
-         for checkpoint in checkpoints[1:]:
-             current_parameterNames = set(checkpoint.keys())
-             if current_parameterNames != parameter_names:
-                 raise ValueError(
-                     "Differing parameter names in models. "
-                     f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
-                 )
-
-
- def get_task_wise_weights(num_models: int, init_values: float = None):
+ def get_task_wise_weights(num_models: int, init_values: float = None) -> Tensor:
      """
      This function generates a tensor of weights for each model.

@@ -116,7 +47,7 @@ def get_task_wise_weights(num_models: int, init_values: float = None):
      return torch.full((num_models,), init_values, dtype=torch.float32)


- def _fuse_weights(task_wise_weight: Tensor, tensors: List[Tensor]):
+ def _fuse_weights(task_wise_weight: Tensor, tensors: List[Tensor]) -> Tensor:
      """
      This function fuses the weights of the models.

@@ -158,6 +89,100 @@ def fuse_weights(
158
89
 
159
90
 
160
91
  class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
92
+ """
93
+ A PyTorch module that dynamically merges multiple fine-tuned models using learnable task-wise weights.
94
+
95
+ This class implements a sophisticated model fusion approach where multiple task-specific models
96
+ are combined with a pretrained base model using learnable weights. The fusion is performed
97
+ using task vectors (differences between fine-tuned and pretrained models) that are weighted
98
+ and added to the base model's parameters.
99
+
100
+ The key innovation is that the merging weights are learnable parameters that can be optimized
101
+ during training, allowing the model to automatically learn the optimal combination of different
102
+ task-specific knowledge.
103
+
104
+ Architecture:
105
+ - Base pretrained model (frozen)
106
+ - Multiple task vectors (differences from pretrained model, frozen)
107
+ - Learnable task-wise weights (trainable parameters)
108
+ - Dynamic merging during forward pass
109
+
110
+ Args:
111
+ task_wise_weight (Tensor): Initial weights for each task model. Shape: (num_models,).
112
+ These become learnable parameters that control the contribution of each task vector.
113
+ pretrained_model (TorchModelType): The base pretrained model that serves as the foundation.
114
+ This model is frozen and used as the starting point for merging.
115
+ finetuned_models (List[TorchModelType]): List of fine-tuned models for different tasks.
116
+ These are converted to task vectors (differences from pretrained model) and frozen.
117
+ clamp_weights (bool, optional): Whether to clamp merge weights to [0, 1] range.
118
+ Defaults to True. When True, ensures weights are non-negative and bounded.
119
+ tie_weights (bool, optional): Whether to tie weights during functional call.
120
+ Defaults to False. Used in the underlying PyTorch functional_call.
121
+ strict (bool, optional): Whether to enforce strict parameter matching.
122
+ Defaults to True. Used in the underlying PyTorch functional_call.
123
+ task_vector_dtype (Optional[torch.dtype], optional): Data type for task vectors.
124
+ Defaults to None. Can be used to save memory (e.g., torch.float16).
125
+
126
+ Attributes:
127
+ merge_weight (nn.Parameter): Learnable weights for merging task vectors.
128
+ pretrained_model (TorchModelType): The frozen base model.
129
+ task_vectors (nn.ModuleList): List of frozen task vector models.
130
+ _merged_state_dict (StateDictType): Cached merged state dictionary.
131
+
132
+ Example:
133
+ ```python
134
+ import torch
135
+ import torch.nn as nn
136
+
137
+ # Create example models
138
+ pretrained_model = nn.Linear(10, 5)
139
+ finetuned_model1 = nn.Linear(10, 5) # Fine-tuned on task 1
140
+ finetuned_model2 = nn.Linear(10, 5) # Fine-tuned on task 2
141
+
142
+ # Initialize task-wise weights
143
+ task_weights = torch.tensor([0.3, 0.7]) # Initial weights for 2 tasks
144
+
145
+ # Create merged model
146
+ merged_model = TaskWiseMergedModel(
147
+ task_wise_weight=task_weights,
148
+ pretrained_model=pretrained_model,
149
+ finetuned_models=[finetuned_model1, finetuned_model2],
150
+ clamp_weights=True
151
+ )
152
+
153
+ # Use like a regular PyTorch model
154
+ x = torch.randn(32, 10)
155
+ output = merged_model(x)
156
+
157
+ # Train the merge weights
158
+ optimizer = torch.optim.Adam(merged_model.parameters())
159
+ loss = some_loss_function(output, targets)
160
+ loss.backward()
161
+ optimizer.step()
162
+
163
+ # Get the final merged model
164
+ final_model = merged_model.merge_and_unload()
165
+ ```
166
+
167
+ Training Workflow:
168
+ 1. **Initialization**: Task vectors are computed as differences from pretrained model
169
+ 2. **Forward Pass**: Weights are dynamically merged based on current merge_weight values
170
+ 3. **Loss Computation**: Standard loss computation on model outputs
171
+ 4. **Backpropagation**: Gradients flow through merge_weight parameters
172
+ 5. **Optimization**: merge_weight parameters are updated to improve performance
173
+
174
+ Memory Efficiency:
175
+ - Task vectors can use lower precision (task_vector_dtype)
176
+ - Base model and task vectors are frozen (no gradient computation)
177
+ - Only merge weights require gradients
178
+
179
+ Note:
180
+ - The pretrained model and task vectors are frozen during training
181
+ - Only the merge weights (task_wise_weight) are trainable parameters
182
+ - Task vectors represent the difference between fine-tuned and pretrained models
183
+ - The merged state dict is cached and recomputed when merge weights change
184
+ """
185
+
161
186
  _merged_state_dict: StateDictType = None
162
187
 
163
188
  def __init__(
@@ -170,6 +195,32 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
          strict: bool = True,
          task_vector_dtype: Optional[torch.dtype] = None,
      ):
+         """
+         Initialize the TaskWiseMergedModel.
+
+         This constructor sets up the model by:
+         1. Converting fine-tuned models to task vectors (differences from pretrained)
+         2. Freezing the pretrained model and task vectors
+         3. Setting up learnable merge weights as parameters
+         4. Configuring merging behavior options
+
+         Args:
+             task_wise_weight (Tensor): Initial weights for each task model. Shape: (num_models,).
+                 These values become the starting point for learnable parameters.
+             pretrained_model (TorchModelType): The base pretrained model.
+                 Will be frozen and used as the foundation for merging.
+             finetuned_models (List[TorchModelType]): List of fine-tuned models.
+                 Must have the same architecture as pretrained_model.
+             clamp_weights (bool, optional): Whether to clamp weights to [0, 1]. Defaults to True.
+             tie_weights (bool, optional): Whether to tie weights in functional_call. Defaults to False.
+             strict (bool, optional): Whether to use strict parameter matching. Defaults to True.
+             task_vector_dtype (Optional[torch.dtype], optional): Data type for task vectors.
+                 Defaults to None (same as original models).
+
+         Raises:
+             ValueError: If the number of task_wise_weights doesn't match the number of fine-tuned models.
+             RuntimeError: If models have incompatible architectures.
+         """
          super().__init__()
          self.clamp_weights = clamp_weights
          self.tie_weights = tie_weights
@@ -196,6 +247,24 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):

      @property
      def forward_model(self):
+         """
+         Get a functional model with merged parameters.
+
+         Returns a partial function that applies the pretrained model with the current
+         merged state dictionary. This allows for efficient forward passes without
+         modifying the original model's parameters.
+
+         Returns:
+             Callable: A partial function that can be called with (args, kwargs) to
+                 perform forward pass with merged parameters.
+
+         Example:
+             ```python
+             # Internal usage during forward pass
+             forward_fn = merged_model.forward_model
+             output = forward_fn(args=(x,), kwargs={})
+             ```
+         """
          return functools.partial(
              functional_call,
              self.pretrained_model,
@@ -205,6 +274,43 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
          )

      def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
+         """
+         Merge task vectors with the pretrained model using current merge weights.
+
+         This method computes the merged model parameters by combining the pretrained
+         model with weighted task vectors. The resulting state dictionary represents
+         a model that incorporates knowledge from all task-specific models.
+
+         The merging formula for each parameter is:
+             merged_param = pretrained_param + Σ(weight_i * task_vector_i * mask_i)
+
+         Args:
+             task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
+                 to selectively apply task vectors to specific parameters. Keys should
+                 match parameter names, values should be tensors with the same shape
+                 as the corresponding parameters. Defaults to None (no masking).
+
+         Returns:
+             StateDictType: The merged state dictionary containing combined parameters.
+
+         Example:
+             ```python
+             # Basic merging
+             merged_state = model.merge_weights()
+
+             # Merging with parameter-specific masks
+             masks = {
+                 'layer1.weight': torch.ones_like(model.pretrained_model.layer1.weight),
+                 'layer2.weight': torch.zeros_like(model.pretrained_model.layer2.weight),
+             }
+             masked_state = model.merge_weights(task_vector_mask=masks)
+             ```
+
+         Note:
+             - If clamp_weights is True, merge weights are clamped to [0, 1] range
+             - The merged state dict is cached in _merged_state_dict
+             - Task vector masks allow fine-grained control over which parameters are affected
+         """
          if self.clamp_weights:
              merge_weight = self.merge_weight.clamp(0, 1)
          else:
@@ -222,11 +328,83 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
          return state_dict

      def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
+         """
+         Merge models and return the final merged model.
+
+         This method performs the merging operation and then loads the merged parameters
+         into the pretrained model, returning a standard PyTorch model that can be used
+         independently of the TaskWiseMergedModel wrapper.
+
+         Args:
+             task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
+                 for selective parameter merging. Defaults to None.
+
+         Returns:
+             TorchModelType: The pretrained model with merged parameters loaded.
+                 This is a standalone model that can be used without the wrapper.
+
+         Example:
+             ```python
+             # Train the merged model
+             for epoch in range(num_epochs):
+                 # ... training loop ...
+                 pass
+
+             # Get the final merged model
+             final_model = merged_model.merge_and_unload()
+
+             # Save or use the final model
+             torch.save(final_model.state_dict(), 'merged_model.pth')
+             output = final_model(new_input)
+             ```
+
+         Warning:
+             This method modifies the pretrained_model's parameters in-place.
+             The original pretrained model parameters will be lost.
+         """
          self.merge_weights(task_vector_mask=task_vector_mask)
          self.pretrained_model.load_state_dict(self._merged_state_dict)
          return self.pretrained_model

      def forward(self, *args, **kwargs):
+         """
+         Forward pass through the dynamically merged model.
+
+         This method performs the forward pass by first ensuring the model parameters
+         are merged according to the current merge weights, then applying the merged
+         model to the input data.
+
+         The forward pass involves:
+         1. Check if merged state dict is current (recompute if needed)
+         2. Apply the merged model to inputs using functional_call
+         3. Return the model outputs
+
+         Args:
+             *args: Positional arguments to pass to the underlying model.
+             **kwargs: Keyword arguments to pass to the underlying model.
+
+         Returns:
+             Any: The output of the merged model, typically torch.Tensor or tuple of tensors.
+
+         Example:
+             ```python
+             # Single input
+             x = torch.randn(32, 784)
+             output = merged_model(x)
+
+             # Multiple inputs
+             x1, x2 = torch.randn(32, 784), torch.randn(32, 100)
+             output = merged_model(x1, x2)
+
+             # With keyword arguments
+             output = merged_model(input_ids=input_ids, attention_mask=attention_mask)
+             ```
+
+         Note:
+             - The merged state dict is recomputed if merge weights have changed
+             - This allows for dynamic behavior during training as weights are updated
+             - The computation is efficient as merging only happens when needed
+         """
          if self._merged_state_dict is None:
              self.merge_weights()
          return self.forward_model(args=args, kwargs=kwargs)
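
Note (illustrative, not part of the diff): the new docstrings above document the merging rule `merged_param = pretrained_param + Σ(weight_i * task_vector_i)`. The sketch below shows that rule applied to plain state dicts, assuming only PyTorch; the helper name `merge_task_wise` is invented for this example and does not exist in fusion-bench.

```python
from typing import Dict, List

import torch
from torch import Tensor, nn


def merge_task_wise(
    pretrained: Dict[str, Tensor],
    finetuned: List[Dict[str, Tensor]],
    weights: Tensor,  # shape: (num_models,)
) -> Dict[str, Tensor]:
    """merged = pretrained + sum_i weight_i * (finetuned_i - pretrained)."""
    merged = {}
    for name, base in pretrained.items():
        # Each task vector is the difference from the pretrained parameter.
        merged[name] = base + sum(
            w * (sd[name] - base) for w, sd in zip(weights, finetuned)
        )
    return merged


# Usage with two fine-tuned copies of a small linear layer:
base = nn.Linear(10, 5)
ft1, ft2 = nn.Linear(10, 5), nn.Linear(10, 5)
merged_sd = merge_task_wise(
    base.state_dict(),
    [ft1.state_dict(), ft2.state_dict()],
    torch.tensor([0.3, 0.7]),
)
base.load_state_dict(merged_sd)  # same end state as merge_and_unload()
```

Loading the result back into the base module mirrors what `merge_and_unload()` does inside the wrapper, minus the learnable-weight machinery.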
@@ -1,9 +1,88 @@
+ """
+ Base Program Classes for FusionBench.
+
+ This module defines the foundational abstract base classes for FusionBench programs.
+ These programs serve as the main execution units that orchestrate model fusion
+ workflows, from loading configurations to executing fusion algorithms and
+ evaluating results.
+
+ The base classes provide a consistent interface for all FusionBench programs
+ while allowing for flexible implementations of different fusion workflows.
+ """
+
  from abc import abstractmethod

- from fusion_bench.mixins import BaseYAMLSerializableModel
+ from fusion_bench.mixins import BaseYAMLSerializable
+
+
+ class BaseHydraProgram(BaseYAMLSerializable):
+     """
+     Abstract base class for all FusionBench programs that use Hydra configuration.
+
+     This class serves as the foundation for all FusionBench execution programs,
+     providing a standardized interface for configuration-driven model fusion
+     workflows. It combines the serialization capabilities of BaseYAMLSerializable
+     with the requirement for a main execution method.
+
+     The class is designed to work seamlessly with Hydra's configuration management
+     system, allowing programs to be instantiated and configured through YAML files.
+     This enables flexible, reproducible experiments with different fusion algorithms,
+     model pools, and evaluation tasks.
+
+     Key Features:
+
+     - Configuration-driven execution through Hydra integration
+     - YAML serialization support for experiment reproducibility
+     - Abstract interface ensuring consistent program structure
+     - Integration with FusionBench's modular architecture

+     Typical Usage:
+         Subclasses should implement the `run()` method to define their specific
+         fusion workflow. The program can then be executed through the FusionBench
+         CLI or instantiated directly from configuration files.
+
+     Example:
+         ```python
+         class MyFusionProgram(BaseHydraProgram):
+             def __init__(self, method_config, modelpool_config, taskpool_config):
+                 self.method_config = method_config
+                 self.modelpool_config = modelpool_config
+                 self.taskpool_config = taskpool_config
+
+             def run(self):
+                 # Load components
+                 algorithm = load_algorithm(self.method_config)
+                 modelpool = load_modelpool(self.modelpool_config)
+                 taskpool = load_taskpool(self.taskpool_config)
+
+                 # Execute fusion
+                 merged_model = algorithm.run(modelpool)
+
+                 # Evaluate results
+                 report = taskpool.evaluate(merged_model)
+                 return report
+         ```
+
+     Note:
+         This is an abstract base class and cannot be instantiated directly.
+         Subclasses must implement the `run()` method to provide concrete
+         functionality.
+
+     See Also:
+
+     - [FabricModelFusionProgram][fusion_bench.programs.FabricModelFusionProgram]: Lightning Fabric-based implementation
+     - [BaseYAMLSerializable][fusion_bench.mixins.BaseYAMLSerializable]: Parent class providing serialization
+     - FusionBench CLI documentation for program execution details
+     """

- class BaseHydraProgram(BaseYAMLSerializableModel):
      @abstractmethod
      def run(self):
+         """
+         Execute the main program workflow.
+
+         This abstract method defines the primary entry point for program execution.
+         Subclasses must implement this method to define their specific fusion
+         workflow, including model loading, fusion algorithm execution, and
+         result evaluation.
+         """
          pass
@@ -1,7 +1,8 @@
  import json
  import logging
  import os
- from typing import Callable, Dict, Iterable, Optional, Union  # noqa: F401
+ from pathlib import Path
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Union  # noqa: F401

  import lightning as L
  from lightning.fabric.utilities.rank_zero import rank_zero_only
@@ -9,19 +10,24 @@ from omegaconf import DictConfig, OmegaConf
  from torch import nn
  from tqdm.auto import tqdm

- import fusion_bench.utils.instantiate_utils
- from fusion_bench.method import BaseAlgorithm
+ import fusion_bench
+ from fusion_bench import (
+     BaseAlgorithm,
+     BaseHydraProgram,
+     BaseModelPool,
+     BaseTaskPool,
+     RuntimeConstants,
+     import_object,
+     instantiate,
+     timeit_context,
+ )
  from fusion_bench.mixins import LightningFabricMixin
- from fusion_bench.modelpool import BaseModelPool
- from fusion_bench.programs import BaseHydraProgram
- from fusion_bench.taskpool import BaseTaskPool
- from fusion_bench.utils import import_object, instantiate, timeit_context
  from fusion_bench.utils.hydra_utils import get_hydra_output_dir
  from fusion_bench.utils.json import print_json
+ from fusion_bench.utils.path import create_symlink
  from fusion_bench.utils.rich_utils import print_bordered, print_config_tree
- from fusion_bench.utils.pylogger import getRankZeroLogger

- log = getRankZeroLogger(__name__)
+ log = fusion_bench.get_rankzero_logger(__name__)


  class FabricModelFusionProgram(
@@ -39,6 +45,7 @@ class FabricModelFusionProgram(
          "_fabric": "fabric",
          "fast_dev_run": "fast_dev_run",
          "seed": "seed",
+         "path": "path",
      }

      def __init__(
@@ -56,8 +63,10 @@ class FabricModelFusionProgram(
          fast_dev_run: bool = False,
          seed: Optional[int] = None,
          print_function_call: bool = True,
+         path: DictConfig = None,
          **kwargs,
      ):
+         super().__init__(**kwargs)
          self._method = method
          self._modelpool = modelpool
          self._taskpool = taskpool
@@ -67,8 +76,11 @@ class FabricModelFusionProgram(
          self.merged_model_save_kwargs = merged_model_save_kwargs
          self.fast_dev_run = fast_dev_run
          self.seed = seed
-         fusion_bench.utils.instantiate_utils.PRINT_FUNCTION_CALL = print_function_call
-         super().__init__(**kwargs)
+         self.path = path
+         RuntimeConstants.debug = fast_dev_run
+         RuntimeConstants.print_function_call = print_function_call
+         if path is not None:
+             RuntimeConstants.cache_dir = path.get("cache_dir", None)

          if print_config:
              print_config_tree(
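
Note (illustrative, not part of the diff): `RuntimeConstants` comes from the new `fusion_bench/constants/runtime.py` (entry 8, +57 lines), whose body is not shown here. A minimal sketch, with every field inferred from the assignments in the hunk above; the real class may carry more fields and helpers.

```python
from typing import Optional


class RuntimeConstants:
    """Process-wide runtime flags, set once by the entry-point program (sketch only)."""

    debug: bool = False                # set from `fast_dev_run`
    print_function_call: bool = True   # replaces instantiate_utils.PRINT_FUNCTION_CALL
    cache_dir: Optional[str] = None    # set from the `path` config group
```

Centralizing these flags on one class replaces the previous pattern of mutating a module-level global (`fusion_bench.utils.instantiate_utils.PRINT_FUNCTION_CALL`).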
@@ -164,9 +176,9 @@ class FabricModelFusionProgram(
          self,
          taskpool: BaseTaskPool,
          merged_model: Union[nn.Module, Dict, Iterable],
-         *args,
-         **kwargs,
-     ):
+         *args: Any,
+         **kwargs: Any,
+     ) -> Union[Dict, List, Any]:
          """
          Evaluates the merged model using the provided task pool.

@@ -221,8 +233,16 @@ class FabricModelFusionProgram(
          fabric = self.fabric
          if self.seed is not None:
              L.seed_everything(self.seed)
-         if fabric.global_rank == 0:
-             self._link_hydra_output()
+
+         # create a symbolic link to the hydra output directory
+         if (
+             self.fabric.is_global_zero
+             and self.log_dir is not None
+             and os.path.abspath(self.log_dir) != os.path.abspath(get_hydra_output_dir())
+         ):
+             create_symlink(
+                 get_hydra_output_dir(), self.log_dir, link_name="hydra_output"
+             )

          log.info("Running the model fusion program.")
          # setup the modelpool, method, and taskpool
@@ -243,7 +263,10 @@ class FabricModelFusionProgram(
              compat_load_fn="fusion_bench.compat.taskpool.load_taskpool_from_config",
          )

+         self.method.on_run_start()
          merged_model = self.method.run(self.modelpool)
+         self.method.on_run_end()
+
          if merged_model is None:
              log.info(
                  "No merged model returned by the method. Skipping saving and evaluation."
@@ -261,52 +284,14 @@ class FabricModelFusionProgram(
              if self.report_save_path is not None:
                  # save report (Dict) to a file
                  # if the directory of `save_report` does not exist, create it
-                 if "{log_dir}" in self.report_save_path and self.log_dir is not None:
-                     self.report_save_path = self.report_save_path.format(log_dir=self.log_dir)
+                 if (
+                     "{log_dir}" in self.report_save_path
+                     and self.log_dir is not None
+                 ):
+                     self.report_save_path = self.report_save_path.format(
+                         log_dir=self.log_dir
+                     )
                  os.makedirs(os.path.dirname(self.report_save_path), exist_ok=True)
                  json.dump(report, open(self.report_save_path, "w"))
          else:
              log.info("No task pool specified. Skipping evaluation.")
-
-     @rank_zero_only
-     def _link_hydra_output(self):
-         """
-         Creates a symbolic link to the Hydra output directory within the specified log directory.
-
-         If `self.log_dir` is not None, this method will:
-         1. Retrieve the Hydra output directory using `get_hydra_output_dir()`.
-         2. Create the log directory if it does not already exist.
-         3. Create a symbolic link named "hydra_output_<basename_of_hydra_output_dir>"
-            within the log directory, pointing to the Hydra output directory.
-
-         Note:
-             - The symbolic link is created only if the Hydra output directory is not None.
-             - The `target_is_directory` parameter is set to True to indicate that the target is a directory.
-
-         Raises:
-             OSError: If the symbolic link creation fails.
-         """
-         if self.log_dir is not None:
-             # make symlink to the hydra output directory
-             try:
-                 hydra_output_dir = get_hydra_output_dir()
-             except Exception as e:
-                 hydra_output_dir = None
-
-             if hydra_output_dir is not None:
-                 os.makedirs(self.log_dir, exist_ok=True)
-                 try:
-                     # if the system is windows, use the `mklink` command in "CMD" to create the symlink
-                     if os.name == "nt":
-                         os.system(f"mklink /J {os.path.abspath(os.path.join(self.log_dir, 'hydra_output_' + os.path.basename(hydra_output_dir)))} {os.path.abspath(hydra_output_dir)}")
-                     else:
-                         os.symlink(
-                             hydra_output_dir,
-                             os.path.join(
-                                 self.log_dir,
-                                 "hydra_output_" + os.path.basename(hydra_output_dir),
-                             ),
-                             target_is_directory=True,
-                         )
-                 except OSError as e:
-                     log.warning(f"Failed to create symbolic link: {e}")
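
Note (illustrative, not part of the diff): the removed `_link_hydra_output` method is superseded by `create_symlink` from the new `fusion_bench/utils/path.py` (entry 136, +56 lines), whose implementation is not included in this diff. A minimal sketch consistent with the call site `create_symlink(get_hydra_output_dir(), self.log_dir, link_name="hydra_output")` and with the Windows `mklink /J` fallback the removed code used; the actual helper may differ.

```python
import os


def create_symlink(target: str, directory: str, link_name: str) -> None:
    """Create `directory/link_name` pointing at the `target` directory (sketch only)."""
    os.makedirs(directory, exist_ok=True)
    link_path = os.path.join(directory, link_name)
    if os.path.lexists(link_path):
        # Don't fail when re-running in the same log directory.
        return
    if os.name == "nt":
        # NTFS junctions, unlike symlinks, don't require elevated privileges.
        os.system(
            f'mklink /J "{os.path.abspath(link_path)}" "{os.path.abspath(target)}"'
        )
    else:
        os.symlink(os.path.abspath(target), link_path, target_is_directory=True)
```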