fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -1
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +16 -6
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
- fusion_bench/method/simple_average.py +16 -4
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +6 -6
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/misc.py +48 -2
- fusion_bench/utils/modelscope.py +265 -0
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +34 -27
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/method/ties_merging/ties_merging.py
CHANGED

@@ -9,14 +9,14 @@ Overview of Ties-Merging:
 """
 
 import logging
 
-from typing import Dict, List, Literal, Mapping, Union  # noqa: F401
+from typing import Any, Dict, List, Literal, Mapping, Union  # noqa: F401
 
 import torch
 from torch import Tensor, nn
 
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.type import StateDictType
 
@@ -25,33 +25,22 @@ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_st
 log = logging.getLogger(__name__)
 
 
-class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
-    """
-    TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
-
-    Attributes:
-        scaling_factor (float): The scaling factor to apply to the merged task vector.
-        threshold (float): The threshold for resetting values in the task vector.
-        remove_keys (List[str]): List of keys to remove from the state dictionary.
-        merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
-    """
-
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "scaling_factor": "scaling_factor",
-        "threshold": "threshold",
-        "remove_keys": "remove_keys",
-        "merge_func": "merge_func",
-    }
-
+@auto_register_config
+class TiesMergingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     def __init__(
         self,
         scaling_factor: float,
         threshold: float,
         remove_keys: List[str],
         merge_func: Literal["sum", "mean", "max"],
-        **kwargs,
+        **kwargs: Any,
     ):
         """
+        TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
+
         Initialize the TiesMergingAlgorithm with the given parameters.
 
         Args:
@@ -61,14 +50,12 @@ class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
             merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
             **kwargs: Additional keyword arguments for the base class.
         """
-        self.scaling_factor = scaling_factor
-        self.threshold = threshold
-        self.remove_keys = remove_keys
-        self.merge_func = merge_func
         super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs):
+    def run(
+        self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs: Any
+    ) -> nn.Module:
         """
         Run the TIES merging algorithm to fuse models in the model pool.
 
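This pattern recurs throughout the release: hand-maintained `_config_mapping` dicts and repeated `self.x = x` assignments are replaced by the `@auto_register_config` decorator now exported from `fusion_bench.mixins` (its implementation lives in the updated `fusion_bench/mixins/serialization.py`, which is not shown in this diff). As a rough mental model only, a decorator with the behavior implied by the removed code might look like this hypothetical sketch:

```python
import inspect


def auto_register_config(cls):
    """Hypothetical sketch, not the library's code: record every named
    __init__ parameter in the class's config mapping and assign it onto
    the instance, replacing the removed _config_mapping boilerplate."""
    sig = inspect.signature(cls.__init__)
    names = [
        n
        for n, p in sig.parameters.items()
        if n != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    ]
    cls._config_mapping = {
        **getattr(cls, "_config_mapping", {}),
        **{n: n for n in names},
    }
    original_init = cls.__init__

    def wrapped_init(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        bound.apply_defaults()
        for n in names:
            setattr(self, n, bound.arguments[n])  # e.g. self.scaling_factor = ...
        original_init(*bound.args, **bound.kwargs)

    cls.__init__ = wrapped_init
    return cls
```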
fusion_bench/method/we_moe/clip_we_moe.py
CHANGED

@@ -2,6 +2,7 @@ import functools
 import logging
 import os
 from copy import deepcopy
+from typing import Any, Iterator
 
 import torch
 from torch import Tensor
@@ -38,7 +39,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
 
     modelpool: CLIPVisionModelPool = None
 
-    def load_checkpoint(self, model, checkpoint):
+    def load_checkpoint(self, model: Any, checkpoint: Any):
         """
         Load the checkpoint file.
 
@@ -49,7 +50,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
         state = {"model": model}
         self._fabric.load(checkpoint, state)
 
-    def save_checkpoint(self, model, checkpoint):
+    def save_checkpoint(self, model: Any, checkpoint: Any):
         """
         Save the checkpoint file.
 
@@ -102,7 +103,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
         return moe_model
 
     @functools.cache
-    def get_shuffled_test_loader_iter(self, tta_dataset: str):
+    def get_shuffled_test_loader_iter(self, tta_dataset: str) -> Iterator:
         """
         Get an iterator for the shuffled test data loader.
 
@@ -131,7 +132,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
         """
         self.setup_zero_shot_classification_head()
 
-    def compute_logits(self, module, batch, task) -> Tensor:
+    def compute_logits(self, module: Any, batch: Any, task: Any) -> Tensor:
         """
         Compute the logits for the given batch and task.
 
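One detail worth noting above: `get_shuffled_test_loader_iter` keeps its `@functools.cache` decorator, so repeated calls with the same dataset name return the same iterator object, presumably so that successive test-time-adaptation steps keep drawing batches from one shuffled stream instead of rebuilding a DataLoader each step. A self-contained illustration of that caching behavior (the `Example` class and fake batches are hypothetical stand-ins):

```python
import functools
from itertools import cycle


class Example:
    @functools.cache  # memoized per (self, tta_dataset) pair
    def get_shuffled_test_loader_iter(self, tta_dataset: str):
        # stand-in for an infinite iterator over a shuffled DataLoader
        return iter(cycle(f"{tta_dataset}-batch-{i}" for i in range(3)))


ex = Example()
it = ex.get_shuffled_test_loader_iter("mnist")
assert it is ex.get_shuffled_test_loader_iter("mnist")  # cached: same object
print(next(it), next(it))  # mnist-batch-0 mnist-batch-1
```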
fusion_bench/method/we_moe/we_moe.py
CHANGED

@@ -1,6 +1,6 @@
 import logging
 from abc import abstractmethod
-from typing import cast  # noqa: F401
+from typing import Any, cast  # noqa: F401
 
 import lightning as L
 import lightning.fabric.wrappers
@@ -70,7 +70,7 @@ class WeightEnsemblingMoEAlgorithm(
         assert "No CUDA device available."
 
     @abstractmethod
-    def load_checkpoint(self, model, checkpoint):
+    def load_checkpoint(self, model: Any, checkpoint: Any):
         """
         Load the checkpoint file.
 
@@ -81,7 +81,7 @@ class WeightEnsemblingMoEAlgorithm(
         pass
 
     @abstractmethod
-    def save_checkpoint(self, model, checkpoint):
+    def save_checkpoint(self, model: Any, checkpoint: Any):
         """
         Save the checkpoint file.
 
@@ -121,7 +121,7 @@ class WeightEnsemblingMoEAlgorithm(
         pass
 
     @abstractmethod
-    def compute_logits(self, module, batch, task) -> Tensor:
+    def compute_logits(self, module: Any, batch: Any, task: Any) -> Tensor:
         """
         Compute the logits for a given batch and task.
 
@@ -135,7 +135,7 @@ class WeightEnsemblingMoEAlgorithm(
         """
         pass
 
-    def test_time_adaptation(self, module: WeightEnsemblingMoE):
+    def test_time_adaptation(self, module: WeightEnsemblingMoE) -> WeightEnsemblingMoE:
         """
         Perform test-time adaptation for the given module.
 
@@ -208,7 +208,7 @@ class WeightEnsemblingMoEAlgorithm(
 
         return module
 
-    def run(self, modelpool: ModelPool):
+    def run(self, modelpool: ModelPool) -> WeightEnsemblingMoE:
         """
         Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.
 
fusion_bench/method/weighted_average/llama.py
CHANGED

@@ -3,6 +3,7 @@ from typing import List, Mapping, Union  # noqa: F401
 
 import numpy as np
 import torch
+from transformers import PreTrainedModel
 from typing_extensions import override
 
 from fusion_bench.method import BaseAlgorithm
@@ -10,24 +11,17 @@ from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
 from fusion_bench.utils.type import StateDictType
+from fusion_bench.mixins import auto_register_config
 
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class WeightedAverageForLLama(BaseAlgorithm):
     """
     A class to perform weighted averaging of LlaMa/Mistral models.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "normalize": "normalize",
-        "weights": "weights",
-        "backbone_only": "backbone_only",
-        "merged_model_save_path": "merged_model_save_path",
-        "save_tokenizer": "save_tokenizer",
-        "push_to_hub": "push_to_hub",
-    }
-
     def __init__(
         self,
         normalize: bool,
@@ -49,17 +43,11 @@ class WeightedAverageForLLama(BaseAlgorithm):
             save_tokenizer (bool): Whether to save the tokenizer.
             push_to_hub (bool): Whether to push the model to the hub.
         """
-        self.normalize = normalize
-        self.weights = weights
-        self.backbone_only = backbone_only
-        self.merged_model_save_path = merged_model_save_path
-        self.save_tokenizer = save_tokenizer
-        self.push_to_hub = push_to_hub
         super().__init__(**kwargs)
 
     @override
     @torch.no_grad()
-    def run(self, modelpool: CausalLMPool):
+    def run(self, modelpool: CausalLMPool) -> PreTrainedModel:
         """
         Executes the weighted averaging of models in the provided model pool.
 
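For context, the weighted average itself rests on the `state_dict_add` / `state_dict_mul` helpers imported above, whose exact signatures this diff does not show. A plausible, self-contained sketch of the core computation (hypothetical code, not the class's actual `run` body):

```python
import torch


def weighted_average(state_dicts, weights, normalize=True):
    # Optionally normalize the weights to sum to 1, then combine the
    # state dicts key by key: merged = sum_i w_i * sd_i.
    if normalize:
        total = sum(weights)
        weights = [w / total for w in weights]
    merged = {k: torch.zeros_like(v) for k, v in state_dicts[0].items()}
    for sd, w in zip(state_dicts, weights):
        for k in merged:
            merged[k] += w * sd[k]
    return merged


a = {"w": torch.tensor([1.0, 2.0])}
b = {"w": torch.tensor([3.0, 4.0])}
print(weighted_average([a, b], weights=[1.0, 3.0])["w"])  # tensor([2.5000, 3.5000])
```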
fusion_bench/metrics/continual_learning/__init__.py
CHANGED

@@ -0,0 +1 @@
+from .backward_transfer import compute_backward_transfer
fusion_bench/metrics/continual_learning/backward_transfer.py
CHANGED

@@ -10,7 +10,7 @@ def compute_backward_transfer(
     Compute the backward transfer (BWT) of a model on a set of tasks.
 
     Equation:
-        BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{
+        $BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{T,i}[k] - acc_{i,i}[k])$
 
     Returns:
         float: The backward transfer of the model.
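Under the conventional continual-learning reading of the indices (an assumption here, since the diff shows only this one docstring line), the fixed equation compares the final model's accuracy on each earlier task with the accuracy measured right after that task was learned, so a negative BWT indicates forgetting. A small worked example:

```python
# acc[i][k]: accuracy on task k, evaluated after training on task i
acc = [
    [0.90, 0.00, 0.00],  # after task 0
    [0.85, 0.92, 0.00],  # after task 1
    [0.80, 0.88, 0.95],  # after task 2 (final model)
]
n = 2  # earlier tasks compared against the final model
bwt = sum(acc[-1][k] - acc[k][k] for k in range(n)) / n
print(round(bwt, 4))  # ((0.80 - 0.90) + (0.88 - 0.92)) / 2 = -0.07
```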
fusion_bench/metrics/nyuv2/__init__.py
CHANGED

@@ -1,10 +1,10 @@
 from .depth import DepthMetric
 from .noise import NoiseMetric
 from .normal import NormalMetric
-from .segmentation import
+from .segmentation import SegmentationMetric
 
 metric_classes = {
-    "segmentation":
+    "segmentation": SegmentationMetric,
     "depth": DepthMetric,
     "normal": NormalMetric,
     "noise": NoiseMetric,
fusion_bench/mixins/__init__.py
CHANGED
@@ -11,7 +11,11 @@ _import_structure = {
     "hydra_config": ["HydraConfigMixin"],
     "lightning_fabric": ["LightningFabricMixin"],
     "openclip_classification": ["OpenCLIPClassificationMixin"],
-    "serialization": [
+    "serialization": [
+        "BaseYAMLSerializable",
+        "YAMLSerializationMixin",
+        "auto_register_config",
+    ],
     "simple_profiler": ["SimpleProfilerMixin"],
 }
 
@@ -21,7 +25,11 @@ if TYPE_CHECKING:
     from .hydra_config import HydraConfigMixin
     from .lightning_fabric import LightningFabricMixin
     from .openclip_classification import OpenCLIPClassificationMixin
-    from .serialization import
+    from .serialization import (
+        BaseYAMLSerializable,
+        YAMLSerializationMixin,
+        auto_register_config,
+    )
     from .simple_profiler import SimpleProfilerMixin
 else:
     sys.modules[__name__] = LazyImporter(
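Both hunks feed the same mechanism: `_import_structure` declares which names each submodule exports, the `if TYPE_CHECKING:` branch imports them eagerly for static analysis only, and at runtime the module is replaced by a `LazyImporter` that defers the real imports until a name is first accessed. FusionBench's actual `LazyImporter` is not shown in this diff; a minimal hypothetical sketch of the pattern:

```python
import importlib
import types


class LazyImporter(types.ModuleType):
    """Hypothetical sketch: resolve exported names to their submodules
    on first attribute access, then cache the result."""

    def __init__(self, name, module_file, import_structure):
        super().__init__(name)
        self.__file__ = module_file
        # reverse map: exported name -> submodule that defines it
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, name):
        if name in self._name_to_module:
            module = importlib.import_module(
                "." + self._name_to_module[name], self.__name__
            )
            value = getattr(module, name)
            setattr(self, name, value)  # cache for subsequent lookups
            return value
        raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")


# usage (as in the hunk above):
# sys.modules[__name__] = LazyImporter(__name__, __file__, _import_structure)
```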
fusion_bench/mixins/clip_classification.py
CHANGED

@@ -6,6 +6,7 @@ from typing import (  # noqa: F401
     TYPE_CHECKING,
     Any,
     Dict,
+    Iterator,
     List,
     Optional,
     Tuple,
@@ -48,7 +49,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
     - `zeroshot_weights_cache_dir` (Optional[str]): The directory to cache the zero-shot weights.
     """
 
-
+    dataloader_kwargs: Dict[str, Any] = {}
     # the modelpool is set by inheriting class
     modelpool: CLIPVisionModelPool = None
     _clip_processor: CLIPProcessor = None
@@ -71,7 +72,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
         batch_size: Optional[int] = None,
         num_workers: Optional[int] = None,
         **loader_kwargs,
-    ):
+    ) -> Iterator:
         """
         Get an iterator for a shuffled test DataLoader.
 
@@ -89,7 +90,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
             Iterator: An iterator over the shuffled test DataLoader.
         """
         # get dataloader kwargs
-        dataloader_kwargs = self.
+        dataloader_kwargs = self.dataloader_kwargs.copy()
         dataloader_kwargs["shuffle"] = True
         if batch_size is not None:
             dataloader_kwargs["batch_size"] = batch_size
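The last hunk is a genuine bug fix rather than a type-hint change: `dataloader_kwargs` is now a class-level dict shared by every instance, so it must be copied before the method sets `shuffle` and `batch_size`; mutating it in place would leak one call's overrides into every later call. A self-contained illustration of the idiom (hypothetical class):

```python
class Mixin:
    dataloader_kwargs = {}  # class-level default, shared by all instances

    def loader_kwargs(self, batch_size=None):
        kwargs = self.dataloader_kwargs.copy()  # mutate a copy, not the shared dict
        kwargs["shuffle"] = True
        if batch_size is not None:
            kwargs["batch_size"] = batch_size
        return kwargs


m = Mixin()
m.loader_kwargs(batch_size=64)
assert Mixin.dataloader_kwargs == {}  # the shared default stays pristine
```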
fusion_bench/mixins/hydra_config.py
CHANGED

@@ -1,8 +1,20 @@
+"""
+Hydra Configuration Mixin for FusionBench.
+
+This module provides a mixin class that enables easy instantiation of objects
+from Hydra configuration files. It's designed to work seamlessly with the
+FusionBench configuration system and supports dynamic object creation based
+on YAML configuration files.
+
+The mixin integrates with Hydra's configuration management system to provide
+a clean interface for creating objects from structured configurations.
+"""
+
 import logging
 import os
 from copy import deepcopy
 from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, TypeVar, Union
 
 import hydra.core.global_hydra
 from hydra import compose, initialize
@@ -13,10 +25,39 @@ from fusion_bench.utils.instantiate_utils import set_print_function_call
 
 log = logging.getLogger(__name__)
 
+T = TypeVar("T", bound="HydraConfigMixin")
+
 
 class HydraConfigMixin:
-    """
-    A mixin
+    R"""
+    A mixin class that provides configuration-based instantiation capabilities.
+
+    This mixin enables classes to be instantiated directly from Hydra configuration
+    files, supporting both direct instantiation and target-based instantiation patterns.
+    It's particularly useful in FusionBench for creating model pools, task pools,
+    and fusion algorithms from YAML configurations.
+
+    The mixin handles:
+    - Configuration loading and composition
+    - Target class validation
+    - Nested configuration group navigation
+    - Object instantiation with proper error handling
+
+    Example:
+
+        ```python
+        class MyAlgorithm(HydraConfigMixin):
+            def __init__(self, param1: str, param2: int = 10):
+                self.param1 = param1
+                self.param2 = param2
+
+        # Instantiate from config
+        algorithm = MyAlgorithm.from_config("algorithms/my_algorithm")
+        ```
+
+    Note:
+        This mixin requires Hydra to be properly initialized before use.
+        Typically, this is handled by the main FusionBench CLI application.
     """
 
     @classmethod
@@ -24,26 +65,83 @@ class HydraConfigMixin:
         cls,
         config_name: Union[str, Path],
         overrides: Optional[List[str]] = None,
-    ):
+    ) -> T:
+        """
+        Create an instance of the class from a Hydra configuration.
+
+        This method loads a Hydra configuration file and instantiates the class
+        using the configuration parameters. It supports both direct parameter
+        passing and target-based instantiation patterns.
+
+        Args:
+            config_name: The name/path of the configuration file to load.
+                Can be a string like "algorithms/simple_average" or
+                a Path object. The .yaml extension is optional.
+            overrides: Optional list of configuration overrides in the format
+                ["key=value", "nested.key=value"]. These allow runtime
+                modification of configuration parameters.
+
+        Returns:
+            An instance of the class configured according to the loaded configuration.
+
+        Raises:
+            RuntimeError: If Hydra is not properly initialized.
+            ImportError: If a target class specified in the config cannot be imported.
+            ValueError: If required configuration parameters are missing.
+
+        Example:
+            ```python
+            # Load with basic config
+            obj = MyClass.from_config("my_config")
+
+            # Load with overrides
+            obj = MyClass.from_config(
+                "my_config",
+                overrides=["param1=new_value", "param2=42"]
+            )
+
+            # Load nested config
+            obj = MyClass.from_config("category/subcategory/my_config")
+            ```
+
+        Note:
+            The method automatically handles nested configuration groups by
+            navigating through the configuration hierarchy based on the
+            config_name path structure.
+        """
+        # Verify Hydra initialization
         if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
-            raise RuntimeError(
+            raise RuntimeError(
+                "Hydra is not initialized. Please ensure Hydra is properly "
+                "initialized before calling from_config(). This is typically "
+                "handled by the FusionBench CLI application."
+            )
         else:
+            # Compose the configuration with any provided overrides
             cfg = compose(config_name=config_name, overrides=overrides)
 
+        # Navigate through nested configuration groups
+        # E.g., "algorithms/simple_average" -> navigate to cfg.algorithms
        config_groups = config_name.split("/")[:-1]
        for config_group in config_groups:
            cfg = cfg[config_group]
 
+        # Handle target-based instantiation
        if "_target_" in cfg:
-            #
+            # Validate that the target class matches the calling class
            target_cls = import_object(cfg["_target_"])
            if target_cls != cls:
                log.warning(
-                    f"
+                    f"Configuration target mismatch: config specifies "
+                    f"'{cfg['_target_']}' but called on class '{cls.__name__}'. "
+                    f"This may indicate a configuration error."
                )
+
+            # Instantiate using the target pattern with function call logging disabled
            with set_print_function_call(False):
                obj = instantiate(cfg)
        else:
+            # Direct instantiation using configuration as keyword arguments
            obj = cls(**cfg)
 
        return obj
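A small typing note on this file: the new `T = TypeVar("T", bound="HydraConfigMixin")` together with the `-> T` return annotation is what lets type checkers report `MyPool.from_config(...)` as returning `MyPool` rather than the base mixin; for the variable to bind, `cls` is conventionally annotated as `type[T]`. A minimal illustration with hypothetical names:

```python
from typing import List, Optional, TypeVar

T = TypeVar("T", bound="Configurable")


class Configurable:
    @classmethod
    def from_config(
        cls: "type[T]",
        config_name: str,
        overrides: Optional[List[str]] = None,
    ) -> T:
        # stand-in for: cfg = compose(config_name, overrides); return cls(**cfg)
        return cls()


class MyPool(Configurable):
    pass


pool = MyPool.from_config("modelpool/my_pool")  # inferred as MyPool
assert isinstance(pool, MyPool)
```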
fusion_bench/mixins/lightning_fabric.py
CHANGED

@@ -52,9 +52,11 @@ class LightningFabricMixin:
     and nodes, with support for custom logging via TensorBoard.
 
     Attributes:
+
     - _fabric (L.Fabric): The Lightning Fabric instance used for distributed computing.
 
     Note:
+
     This mixin is designed to be used with classes that require distributed computing capabilities and wish to
     leverage the Lightning Fabric for this purpose. It assumes the presence of a `config` attribute or parameter
     in the consuming class for configuration.