fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193)
  1. fusion_bench/__init__.py +1 -0
  2. fusion_bench/_get_started/__init__.py +3 -0
  3. fusion_bench/_get_started/greeting_program.py +49 -0
  4. fusion_bench/compat/method/base_algorithm.py +14 -0
  5. fusion_bench/constants/__init__.py +5 -0
  6. fusion_bench/constants/clip_vision.py +26 -2
  7. fusion_bench/constants/paths.py +4 -0
  8. fusion_bench/dataset/clip_dataset.py +2 -1
  9. fusion_bench/dataset/gpt2_glue.py +9 -9
  10. fusion_bench/dataset/image_corruption/__init__.py +0 -0
  11. fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
  12. fusion_bench/dataset/image_dataset.py +1 -1
  13. fusion_bench/dataset/nyuv2.py +2 -2
  14. fusion_bench/method/__init__.py +16 -1
  15. fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
  16. fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
  17. fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
  18. fusion_bench/method/base_algorithm.py +195 -12
  19. fusion_bench/method/bitdelta/__init__.py +4 -0
  20. fusion_bench/method/bitdelta/bitdelta.py +156 -0
  21. fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
  22. fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
  23. fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
  24. fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
  25. fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
  26. fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
  27. fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
  28. fusion_bench/method/doge_ta/doge_ta.py +1 -1
  29. fusion_bench/method/ensemble.py +12 -12
  30. fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
  31. fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
  32. fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
  33. fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
  34. fusion_bench/method/fw_merging/fw_hard.py +1 -1
  35. fusion_bench/method/fw_merging/fw_soft.py +1 -1
  36. fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
  37. fusion_bench/method/linear/expo.py +2 -1
  38. fusion_bench/method/linear/linear_interpolation.py +6 -4
  39. fusion_bench/method/linear/simple_average_for_llama.py +16 -6
  40. fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
  41. fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
  42. fusion_bench/method/model_recombination.py +2 -5
  43. fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
  44. fusion_bench/method/moe_pruner/utils/data.py +2 -1
  45. fusion_bench/method/moe_pruner/utils/prune.py +6 -1
  46. fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
  47. fusion_bench/method/pruning/wanda_utils/data.py +1 -2
  48. fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
  49. fusion_bench/method/randes/modelsoup.py +1 -3
  50. fusion_bench/method/regmean/clip_regmean.py +2 -2
  51. fusion_bench/method/regmean/gpt2_regmean.py +3 -10
  52. fusion_bench/method/regmean/regmean.py +2 -11
  53. fusion_bench/method/regmean_plusplus/__init__.py +3 -0
  54. fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
  55. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
  56. fusion_bench/method/simple_average.py +16 -4
  57. fusion_bench/method/slerp/slerp.py +5 -2
  58. fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
  59. fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
  60. fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
  61. fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
  62. fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
  63. fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
  64. fusion_bench/method/ties_merging/ties_merging.py +13 -26
  65. fusion_bench/method/we_moe/clip_we_moe.py +5 -4
  66. fusion_bench/method/we_moe/we_moe.py +6 -6
  67. fusion_bench/method/weighted_average/llama.py +4 -16
  68. fusion_bench/metrics/continual_learning/__init__.py +1 -0
  69. fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
  70. fusion_bench/metrics/nyuv2/__init__.py +2 -2
  71. fusion_bench/metrics/nyuv2/segmentation.py +1 -1
  72. fusion_bench/mixins/__init__.py +10 -2
  73. fusion_bench/mixins/clip_classification.py +4 -3
  74. fusion_bench/mixins/hydra_config.py +105 -7
  75. fusion_bench/mixins/lightning_fabric.py +2 -0
  76. fusion_bench/mixins/serialization.py +265 -48
  77. fusion_bench/modelpool/__init__.py +2 -2
  78. fusion_bench/modelpool/base_pool.py +29 -9
  79. fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
  80. fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
  81. fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
  82. fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
  83. fusion_bench/models/__init__.py +2 -1
  84. fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
  85. fusion_bench/models/hf_utils.py +182 -0
  86. fusion_bench/models/linearized/linearized_model_utils.py +4 -4
  87. fusion_bench/models/linearized/vision_model.py +1 -1
  88. fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
  89. fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
  90. fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
  91. fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
  92. fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
  93. fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
  94. fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
  95. fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
  96. fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
  97. fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
  98. fusion_bench/models/modeling_smile_llama/register.py +8 -0
  99. fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
  100. fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
  101. fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
  102. fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
  103. fusion_bench/models/parameter_dict.py +1 -1
  104. fusion_bench/models/sparse_we_moe.py +1 -53
  105. fusion_bench/models/utils.py +26 -0
  106. fusion_bench/models/we_moe.py +1 -53
  107. fusion_bench/models/wrappers/ensemble.py +6 -4
  108. fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
  109. fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
  110. fusion_bench/programs/base_program.py +81 -2
  111. fusion_bench/programs/fabric_fusion_program.py +24 -8
  112. fusion_bench/scripts/cli.py +6 -6
  113. fusion_bench/taskpool/base_pool.py +4 -3
  114. fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
  115. fusion_bench/taskpool/dummy.py +1 -1
  116. fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
  117. fusion_bench/tasks/clip_classification/__init__.py +6 -4
  118. fusion_bench/utils/__init__.py +6 -1
  119. fusion_bench/utils/devices.py +14 -4
  120. fusion_bench/utils/instantiate_utils.py +3 -1
  121. fusion_bench/utils/misc.py +48 -2
  122. fusion_bench/utils/modelscope.py +265 -0
  123. fusion_bench/utils/parameters.py +2 -2
  124. fusion_bench/utils/rich_utils.py +3 -0
  125. fusion_bench/utils/state_dict_arithmetic.py +34 -27
  126. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
  127. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
  128. fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
  129. fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
  130. fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
  131. fusion_bench_config/_get_started/greeting_program.yaml +4 -0
  132. fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
  133. fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
  134. fusion_bench_config/fabric_model_fusion.yaml +45 -17
  135. fusion_bench_config/hydra/default.yaml +6 -2
  136. fusion_bench_config/llama_full_finetune.yaml +1 -0
  137. fusion_bench_config/method/adamerging/clip.yaml +1 -1
  138. fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
  139. fusion_bench_config/method/depth_upscaling.yaml +4 -1
  140. fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
  141. fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
  142. fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
  143. fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
  144. fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
  145. fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
  146. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
  147. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
  148. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
  149. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
  150. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
  151. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
  152. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
  153. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
  154. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
  155. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
  156. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
  157. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
  158. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
  159. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
  160. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
  161. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
  162. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
  163. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
  164. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
  165. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
  166. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
  167. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
  168. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
  169. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
  170. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
  171. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
  172. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
  173. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
  174. fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
  175. fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
  176. fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
  177. fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
  178. fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
  179. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
  180. fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
  181. fusion_bench_config/nyuv2_config.yaml +3 -1
  182. fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
  183. fusion_bench_config/path/default.yaml +28 -0
  184. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
  185. fusion_bench_config/method/adamerging.yaml +0 -23
  186. fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
  187. fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
  188. fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
  189. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
  190. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
  191. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
  192. {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
  193. /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0

fusion_bench/method/ties_merging/ties_merging.py
@@ -9,14 +9,14 @@ Overview of Ties-Merging:
 """
 
 import logging
-from typing import Dict, List, Literal, Mapping, Union  # noqa: F401
+from typing import Any, Dict, List, Literal, Mapping, Union  # noqa: F401
 
 import torch
 from torch import Tensor, nn
 
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.type import StateDictType
 
@@ -25,33 +25,22 @@ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_st
 log = logging.getLogger(__name__)
 
 
-class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
-    """
-    TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
-
-    Attributes:
-        scaling_factor (float): The scaling factor to apply to the merged task vector.
-        threshold (float): The threshold for resetting values in the task vector.
-        remove_keys (List[str]): List of keys to remove from the state dictionary.
-        merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
-    """
-
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "scaling_factor": "scaling_factor",
-        "threshold": "threshold",
-        "remove_keys": "remove_keys",
-        "merge_func": "merge_func",
-    }
-
+@auto_register_config
+class TiesMergingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     def __init__(
         self,
         scaling_factor: float,
         threshold: float,
         remove_keys: List[str],
         merge_func: Literal["sum", "mean", "max"],
-        **kwargs,
+        **kwargs: Any,
     ):
         """
+        TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
+
         Initialize the TiesMergingAlgorithm with the given parameters.
 
         Args:
@@ -61,14 +50,12 @@ class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
             merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
             **kwargs: Additional keyword arguments for the base class.
         """
-        self.scaling_factor = scaling_factor
-        self.threshold = threshold
-        self.remove_keys = remove_keys
-        self.merge_func = merge_func
         super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs):
+    def run(
+        self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs: Any
+    ) -> nn.Module:
         """
         Run the TIES merging algorithm to fuse models in the model pool.
 
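Note: the change above is representative of a refactoring repeated throughout this release (see also WeightedAverageForLLama below): hand-written `_config_mapping` tables and per-argument `self.x = x` assignments are replaced by the `@auto_register_config` decorator newly exported from `fusion_bench.mixins`. The decorator's real implementation lives in `fusion_bench/mixins/serialization.py` and is not shown in this diff; the following is only a minimal sketch of the behavior the call sites imply, namely deriving the config mapping from the `__init__` signature and auto-assigning constructor arguments:

```python
import inspect


def auto_register_config(cls):
    """Sketch (assumed behavior, not the actual fusion_bench implementation):
    register explicit __init__ parameters in _config_mapping and assign them
    to instance attributes automatically."""
    sig = inspect.signature(cls.__init__)
    params = [
        name
        for name, p in sig.parameters.items()
        if name != "self"
        and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    ]
    # Extend any mapping inherited from the base class instead of re-listing it.
    cls._config_mapping = getattr(cls, "_config_mapping", {}) | {n: n for n in params}

    original_init = cls.__init__

    def __init__(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for name in params:
            # Mirrors the removed `self.scaling_factor = scaling_factor` lines.
            setattr(self, name, bound.arguments[name])
        original_init(*bound.args, **bound.kwargs)

    cls.__init__ = __init__
    return cls
```

Under this reading, `TiesMergingAlgorithm.__init__` only has to forward `**kwargs` to `super().__init__`, which is exactly what remains after the deletion.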

fusion_bench/method/we_moe/clip_we_moe.py
@@ -2,6 +2,7 @@ import functools
 import logging
 import os
 from copy import deepcopy
+from typing import Any, Iterator
 
 import torch
 from torch import Tensor
@@ -38,7 +39,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
 
     modelpool: CLIPVisionModelPool = None
 
-    def load_checkpoint(self, model, checkpoint):
+    def load_checkpoint(self, model: Any, checkpoint: Any):
         """
         Load the checkpoint file.
 
@@ -49,7 +50,7 @@
         state = {"model": model}
         self._fabric.load(checkpoint, state)
 
-    def save_checkpoint(self, model, checkpoint):
+    def save_checkpoint(self, model: Any, checkpoint: Any):
         """
         Save the checkpoint file.
 
@@ -102,7 +103,7 @@
         return moe_model
 
     @functools.cache
-    def get_shuffled_test_loader_iter(self, tta_dataset: str):
+    def get_shuffled_test_loader_iter(self, tta_dataset: str) -> Iterator:
         """
         Get an iterator for the shuffled test data loader.
 
@@ -131,7 +132,7 @@
         """
         self.setup_zero_shot_classification_head()
 
-    def compute_logits(self, module, batch, task) -> Tensor:
+    def compute_logits(self, module: Any, batch: Any, task: Any) -> Tensor:
         """
         Compute the logits for the given batch and task.
 

fusion_bench/method/we_moe/we_moe.py
@@ -1,6 +1,6 @@
 import logging
 from abc import abstractmethod
-from typing import cast  # noqa: F401
+from typing import Any, cast  # noqa: F401
 
 import lightning as L
 import lightning.fabric.wrappers
@@ -70,7 +70,7 @@ class WeightEnsemblingMoEAlgorithm(
         assert "No CUDA device available."
 
     @abstractmethod
-    def load_checkpoint(self, model, checkpoint):
+    def load_checkpoint(self, model: Any, checkpoint: Any):
         """
         Load the checkpoint file.
 
@@ -81,7 +81,7 @@
         pass
 
     @abstractmethod
-    def save_checkpoint(self, model, checkpoint):
+    def save_checkpoint(self, model: Any, checkpoint: Any):
         """
         Save the checkpoint file.
 
@@ -121,7 +121,7 @@
         pass
 
     @abstractmethod
-    def compute_logits(self, module, batch, task) -> Tensor:
+    def compute_logits(self, module: Any, batch: Any, task: Any) -> Tensor:
         """
         Compute the logits for a given batch and task.
 
@@ -135,7 +135,7 @@
         """
         pass
 
-    def test_time_adaptation(self, module: WeightEnsemblingMoE):
+    def test_time_adaptation(self, module: WeightEnsemblingMoE) -> WeightEnsemblingMoE:
         """
         Perform test-time adaptation for the given module.
 
@@ -208,7 +208,7 @@
 
         return module
 
-    def run(self, modelpool: ModelPool):
+    def run(self, modelpool: ModelPool) -> WeightEnsemblingMoE:
         """
         Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.
 

fusion_bench/method/weighted_average/llama.py
@@ -3,6 +3,7 @@ from typing import List, Mapping, Union  # noqa: F401
 
 import numpy as np
 import torch
+from transformers import PreTrainedModel
 from typing_extensions import override
 
 from fusion_bench.method import BaseAlgorithm
@@ -10,24 +11,17 @@ from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
 from fusion_bench.utils.type import StateDictType
+from fusion_bench.mixins import auto_register_config
 
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class WeightedAverageForLLama(BaseAlgorithm):
     """
     A class to perform weighted averaging of LlaMa/Mistral models.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "normalize": "normalize",
-        "weights": "weights",
-        "backbone_only": "backbone_only",
-        "merged_model_save_path": "merged_model_save_path",
-        "save_tokenizer": "save_tokenizer",
-        "push_to_hub": "push_to_hub",
-    }
-
     def __init__(
         self,
         normalize: bool,
@@ -49,17 +43,11 @@ class WeightedAverageForLLama(BaseAlgorithm):
             save_tokenizer (bool): Whether to save the tokenizer.
             push_to_hub (bool): Whether to push the model to the hub.
         """
-        self.normalize = normalize
-        self.weights = weights
-        self.backbone_only = backbone_only
-        self.merged_model_save_path = merged_model_save_path
-        self.save_tokenizer = save_tokenizer
-        self.push_to_hub = push_to_hub
         super().__init__(**kwargs)
 
     @override
     @torch.no_grad()
-    def run(self, modelpool: CausalLMPool):
+    def run(self, modelpool: CausalLMPool) -> PreTrainedModel:
         """
         Executes the weighted averaging of models in the provided model pool.
 

fusion_bench/metrics/continual_learning/__init__.py
@@ -0,0 +1 @@
+from .backward_transfer import compute_backward_transfer

fusion_bench/metrics/continual_learning/backward_transfer.py
@@ -10,7 +10,7 @@ def compute_backward_transfer(
     Compute the backward transfer (BWT) of a model on a set of tasks.
 
     Equation:
-        BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{Ti}[k] - acc_{ii}[k])
+        $BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{T,i}[k] - acc_{i,i}[k])$
 
     Returns:
         float: The backward transfer of the model.
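The corrected equation reads BWT as the average gap between a task's accuracy after the final task has been learned and its accuracy immediately after it was learned itself. A minimal illustration with a hypothetical accuracy matrix (the actual signature of `compute_backward_transfer` is not shown in this diff):

```python
from typing import List


def backward_transfer(acc: List[List[float]]) -> float:
    """acc[i][k] is the accuracy on task k measured after training on task i."""
    T = len(acc)
    n = T - 1  # every task except the last can exhibit backward transfer
    return sum(acc[T - 1][k] - acc[k][k] for k in range(n)) / n


# Task 0 reaches 0.90 when learned, then drops to 0.85 after task 1 is learned:
acc = [
    [0.90, 0.10],
    [0.85, 0.80],
]
print(round(backward_transfer(acc), 4))  # -0.05; negative BWT indicates forgetting
```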

fusion_bench/metrics/nyuv2/__init__.py
@@ -1,10 +1,10 @@
 from .depth import DepthMetric
 from .noise import NoiseMetric
 from .normal import NormalMetric
-from .segmentation import SegmentationMertic
+from .segmentation import SegmentationMetric
 
 metric_classes = {
-    "segmentation": SegmentationMertic,
+    "segmentation": SegmentationMetric,
     "depth": DepthMetric,
     "normal": NormalMetric,
     "noise": NoiseMetric,

fusion_bench/metrics/nyuv2/segmentation.py
@@ -5,7 +5,7 @@ from torch import Tensor, nn
 from torchmetrics import Metric
 
 
-class SegmentationMertic(Metric):
+class SegmentationMetric(Metric):
     metric_names = ["mIoU", "pixAcc"]
 
     def __init__(self, num_classes=13):

fusion_bench/mixins/__init__.py
@@ -11,7 +11,11 @@ _import_structure = {
     "hydra_config": ["HydraConfigMixin"],
     "lightning_fabric": ["LightningFabricMixin"],
     "openclip_classification": ["OpenCLIPClassificationMixin"],
-    "serialization": ["YAMLSerializationMixin", "BaseYAMLSerializableModel"],
+    "serialization": [
+        "BaseYAMLSerializable",
+        "YAMLSerializationMixin",
+        "auto_register_config",
+    ],
     "simple_profiler": ["SimpleProfilerMixin"],
 }
 
@@ -21,7 +25,11 @@ if TYPE_CHECKING:
     from .hydra_config import HydraConfigMixin
     from .lightning_fabric import LightningFabricMixin
     from .openclip_classification import OpenCLIPClassificationMixin
-    from .serialization import BaseYAMLSerializableModel, YAMLSerializationMixin
+    from .serialization import (
+        BaseYAMLSerializable,
+        YAMLSerializationMixin,
+        auto_register_config,
+    )
     from .simple_profiler import SimpleProfilerMixin
 else:
     sys.modules[__name__] = LazyImporter(

fusion_bench/mixins/clip_classification.py
@@ -6,6 +6,7 @@ from typing import (  # noqa: F401
     TYPE_CHECKING,
     Any,
     Dict,
+    Iterator,
     List,
     Optional,
     Tuple,
@@ -48,7 +49,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
     - `zeroshot_weights_cache_dir` (Optional[str]): The directory to cache the zero-shot weights.
     """
 
-    _dataloader_kwargs: Dict[str, Any] = {}
+    dataloader_kwargs: Dict[str, Any] = {}
     # the modelpool is set by inheriting class
     modelpool: CLIPVisionModelPool = None
     _clip_processor: CLIPProcessor = None
@@ -71,7 +72,7 @@
         batch_size: Optional[int] = None,
         num_workers: Optional[int] = None,
         **loader_kwargs,
-    ):
+    ) -> Iterator:
         """
         Get an iterator for a shuffled test DataLoader.
 
@@ -89,7 +90,7 @@
         Iterator: An iterator over the shuffled test DataLoader.
         """
         # get dataloader kwargs
-        dataloader_kwargs = self._dataloader_kwargs.copy()
+        dataloader_kwargs = self.dataloader_kwargs.copy()
         dataloader_kwargs["shuffle"] = True
         if batch_size is not None:
             dataloader_kwargs["batch_size"] = batch_size

fusion_bench/mixins/hydra_config.py
@@ -1,8 +1,20 @@
+"""
+Hydra Configuration Mixin for FusionBench.
+
+This module provides a mixin class that enables easy instantiation of objects
+from Hydra configuration files. It's designed to work seamlessly with the
+FusionBench configuration system and supports dynamic object creation based
+on YAML configuration files.
+
+The mixin integrates with Hydra's configuration management system to provide
+a clean interface for creating objects from structured configurations.
+"""
+
 import logging
 import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, TypeVar, Union
 
 import hydra.core.global_hydra
 from hydra import compose, initialize
@@ -13,10 +25,39 @@ from fusion_bench.utils.instantiate_utils import set_print_function_call
 
 log = logging.getLogger(__name__)
 
+T = TypeVar("T", bound="HydraConfigMixin")
+
 
 class HydraConfigMixin:
-    """
-    A mixin for classes that need to be instantiated from a config file.
+    R"""
+    A mixin class that provides configuration-based instantiation capabilities.
+
+    This mixin enables classes to be instantiated directly from Hydra configuration
+    files, supporting both direct instantiation and target-based instantiation patterns.
+    It's particularly useful in FusionBench for creating model pools, task pools,
+    and fusion algorithms from YAML configurations.
+
+    The mixin handles:
+    - Configuration loading and composition
+    - Target class validation
+    - Nested configuration group navigation
+    - Object instantiation with proper error handling
+
+    Example:
+
+    ```python
+    class MyAlgorithm(HydraConfigMixin):
+        def __init__(self, param1: str, param2: int = 10):
+            self.param1 = param1
+            self.param2 = param2
+
+    # Instantiate from config
+    algorithm = MyAlgorithm.from_config("algorithms/my_algorithm")
+    ```
+
+    Note:
+        This mixin requires Hydra to be properly initialized before use.
+        Typically, this is handled by the main FusionBench CLI application.
     """
 
     @classmethod
@@ -24,26 +65,83 @@ class HydraConfigMixin:
         cls,
         config_name: Union[str, Path],
         overrides: Optional[List[str]] = None,
-    ):
+    ) -> T:
+        """
+        Create an instance of the class from a Hydra configuration.
+
+        This method loads a Hydra configuration file and instantiates the class
+        using the configuration parameters. It supports both direct parameter
+        passing and target-based instantiation patterns.
+
+        Args:
+            config_name: The name/path of the configuration file to load.
+                Can be a string like "algorithms/simple_average" or
+                a Path object. The .yaml extension is optional.
+            overrides: Optional list of configuration overrides in the format
+                ["key=value", "nested.key=value"]. These allow runtime
+                modification of configuration parameters.
+
+        Returns:
+            An instance of the class configured according to the loaded configuration.
+
+        Raises:
+            RuntimeError: If Hydra is not properly initialized.
+            ImportError: If a target class specified in the config cannot be imported.
+            ValueError: If required configuration parameters are missing.
+
+        Example:
+            ```python
+            # Load with basic config
+            obj = MyClass.from_config("my_config")
+
+            # Load with overrides
+            obj = MyClass.from_config(
+                "my_config",
+                overrides=["param1=new_value", "param2=42"]
+            )
+
+            # Load nested config
+            obj = MyClass.from_config("category/subcategory/my_config")
+            ```
+
+        Note:
+            The method automatically handles nested configuration groups by
+            navigating through the configuration hierarchy based on the
+            config_name path structure.
+        """
+        # Verify Hydra initialization
         if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
-            raise RuntimeError("Hydra is not initialized.")
+            raise RuntimeError(
+                "Hydra is not initialized. Please ensure Hydra is properly "
+                "initialized before calling from_config(). This is typically "
+                "handled by the FusionBench CLI application."
+            )
         else:
+            # Compose the configuration with any provided overrides
            cfg = compose(config_name=config_name, overrides=overrides)
 
+            # Navigate through nested configuration groups
+            # E.g., "algorithms/simple_average" -> navigate to cfg.algorithms
            config_groups = config_name.split("/")[:-1]
            for config_group in config_groups:
                cfg = cfg[config_group]
 
+            # Handle target-based instantiation
            if "_target_" in cfg:
-                # if the config has a _target_ key, check if it is equal to the class name
+                # Validate that the target class matches the calling class
                target_cls = import_object(cfg["_target_"])
                if target_cls != cls:
                    log.warning(
-                        f"The _target_ key in the config is {cfg['_target_']}, but the class name is {cls.__name__}."
+                        f"Configuration target mismatch: config specifies "
+                        f"'{cfg['_target_']}' but called on class '{cls.__name__}'. "
+                        f"This may indicate a configuration error."
                    )
+
+                # Instantiate using the target pattern with function call logging disabled
                with set_print_function_call(False):
                    obj = instantiate(cfg)
            else:
+                # Direct instantiation using configuration as keyword arguments
                obj = cls(**cfg)
 
            return obj

fusion_bench/mixins/lightning_fabric.py
@@ -52,9 +52,11 @@ class LightningFabricMixin:
     and nodes, with support for custom logging via TensorBoard.
 
     Attributes:
+
     - _fabric (L.Fabric): The Lightning Fabric instance used for distributed computing.
 
     Note:
+
     This mixin is designed to be used with classes that require distributed computing capabilities and wish to
     leverage the Lightning Fabric for this purpose. It assumes the presence of a `config` attribute or parameter
     in the consuming class for configuration.