fusion-bench 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/clip_classification.py +26 -6
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/serialization.py +40 -83
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +285 -44
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +10 -4
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +12 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +36 -11
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/lazy_state_dict.py +85 -19
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/rich_utils.py +7 -3
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +77 -64
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
fusion_bench/method/ensemble.py
CHANGED
@@ -17,7 +17,21 @@ from fusion_bench.models.wrappers.ensemble import (
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class SimpleEnsembleAlgorithm(BaseAlgorithm):
+    def __init__(
+        self,
+        device_map: Optional[Mapping[int, Union[str, torch.device]]] = None,
+        **kwargs,
+    ):
+        """
+        Initializes the SimpleEnsembleAlgorithm with an optional device map.
+
+        Args:
+            device_map (Optional[Mapping[int, Union[str, torch.device]]], optional): A mapping from model index to device. Defaults to None.
+        """
+        super().__init__(**kwargs)
+
     @torch.no_grad()
     def run(self, modelpool: BaseModelPool | List[nn.Module]) -> EnsembleModule:
         """
@@ -30,9 +44,10 @@ class SimpleEnsembleAlgorithm(BaseAlgorithm):
             EnsembleModule: The ensembled model.
         """
         log.info(f"Running ensemble algorithm with {len(modelpool)} models")
-
         models = [modelpool.load_model(m) for m in modelpool.model_names]
-        ensemble = EnsembleModule(models=models)
+
+        log.info("creating ensemble module")
+        ensemble = EnsembleModule(models=models, device_map=self.device_map)
         return ensemble
 
 
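In 0.2.24, `SimpleEnsembleAlgorithm` gains an optional `device_map` that is forwarded to `EnsembleModule`, so ensemble members can be pinned to specific devices. A minimal sketch of that call path, assuming the constructor shown in the diff; the toy `nn.Linear` modules stand in for real checkpoints loaded from a modelpool:

```python
import torch
from torch import nn

from fusion_bench.models.wrappers.ensemble import EnsembleModule

# Toy members standing in for real fine-tuned models.
models = [nn.Linear(8, 2) for _ in range(3)]

# Mirrors the call added in SimpleEnsembleAlgorithm.run():
#     EnsembleModule(models=models, device_map=self.device_map)
# device_map is assumed to map model index -> device; None keeps the default placement.
ensemble = EnsembleModule(models=models, device_map={0: "cpu", 1: "cpu", 2: "cpu"})
output = ensemble(torch.randn(4, 8))
```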
fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py
CHANGED
@@ -23,6 +23,7 @@ from transformers import MixtralForCausalLM
 from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
 
 import fusion_bench as fb
+from fusion_bench import auto_register_config
 from fusion_bench.method.expert_sparsity.utils.calibration_data import (
     build_calib_loader,
 )
@@ -97,6 +98,7 @@ def dynamic_skipping(
     return model, (res_median, res_mean)
 
 
+@auto_register_config
 class DynamicSkippingPruningForMixtral(
     fb.BaseAlgorithm,
     fb.mixins.LightningFabricMixin,
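All three Mixtral expert-sparsity algorithms (dynamic skipping, layer-wise pruning, progressive pruning) are now decorated with `@auto_register_config`. Judging from the other diffs in this release, where manual `_config_mapping` entries and `self.x = x` assignments are dropped once the decorator is applied (see `clip_regmean.py` below), the decorator records `__init__` keyword arguments in the algorithm config and exposes them as attributes. A rough sketch of the pattern with hypothetical parameters; the decorator's exact behaviour is inferred here, not documented in this diff:

```python
from fusion_bench import BaseAlgorithm, auto_register_config


@auto_register_config
class MyPruningAlgorithm(BaseAlgorithm):
    # Hypothetical example: the decorator is expected to register these kwargs in the
    # algorithm config and make them available as self.num_experts / self.calib_split,
    # so no manual _config_mapping or self.x = x boilerplate is needed.
    def __init__(self, num_experts: int = 8, calib_split: str = "train", **kwargs):
        super().__init__(**kwargs)
```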
fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py
CHANGED
@@ -22,6 +22,7 @@ from transformers import MixtralForCausalLM
 from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
 
 import fusion_bench as fb
+from fusion_bench import auto_register_config
 from fusion_bench.method.expert_sparsity.utils.calibration_data import (
     build_calib_loader,
 )
@@ -81,6 +82,7 @@ def layerwise_pruning(
     return model, (global_loss_history,)
 
 
+@auto_register_config
 class LayerWisePruningForMixtral(
     fb.BaseAlgorithm,
     fb.mixins.LightningFabricMixin,
fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py
CHANGED
@@ -20,6 +20,7 @@ from tqdm import tqdm
 from transformers import MixtralForCausalLM
 
 import fusion_bench as fb
+from fusion_bench import auto_register_config
 from fusion_bench.method.expert_sparsity.utils.calibration_data import (
     build_calib_loader,
 )
@@ -95,6 +96,7 @@ def progressive_pruning(
     return model, (global_loss_history,)
 
 
+@auto_register_config
 class ProgressivePruningForMixtral(
     fb.BaseAlgorithm,
     fb.mixins.LightningFabricMixin,
fusion_bench/method/linear/__init__.py
CHANGED
@@ -2,5 +2,9 @@
 from .expo import ExPOAlgorithm
 from .linear_interpolation import LinearInterpolationAlgorithm
 from .llama_expo import ExPOAlgorithmForLlama
-from .simple_average_for_llama import SimpleAverageForLlama
-from .task_arithmetic_for_llama import TaskArithmeticForLlama
+from .simple_average_for_causallm import SimpleAverageForCausalLM, SimpleAverageForLlama
+from .task_arithmetic_for_causallm import (
+    TaskArithmeticForCausalLM,
+    TaskArithmeticForLlama,
+)
+from .ties_merging_for_causallm import TiesMergingForCausalLM
fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py}
CHANGED
@@ -18,16 +18,16 @@ log = get_rankzero_logger(__name__)
 
 
 @auto_register_config
-class SimpleAverageForLlama(BaseAlgorithm):
+class SimpleAverageForCausalLM(BaseAlgorithm):
     R"""
     A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.
 
     Examples:
-    The following example demonstrates how to use the `SimpleAverageForLlama` algorithm to merge Mistral models.
+    The following example demonstrates how to use the `SimpleAverageForCausalLM` algorithm to merge Mistral models.
 
     ```bash
     fusion_bench \
-        method=linear/simple_average_for_llama \
+        method=linear/simple_average_for_causallm \
         method.model_save_path=outputs/simle_mixtral_exp_v4/simple_average \
         modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
     ```
@@ -35,7 +35,7 @@ class SimpleAverageForLlama(BaseAlgorithm):
 
     def __init__(
         self,
-        merge_backbone: bool,
+        merge_backbone: bool = False,
         model_save_path: Optional[str] = None,
         show_pbar: bool = False,
         **kwargs,
@@ -81,3 +81,7 @@ class SimpleAverageForLlama(BaseAlgorithm):
         with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
             f.write(model_card_str)
         return model
+
+
+SimpleAverageForLlama = SimpleAverageForCausalLM
+"""Alias for SimpleAverageForCausalLM"""
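The rename from `*_for_llama` to `*_for_causallm` is backwards compatible: the old class name is kept as a module-level alias, and both names are re-exported from `fusion_bench.method.linear` (see the `__init__.py` hunk above). A minimal check of that, assuming the import paths shown in the diff:

```python
from fusion_bench.method.linear import (
    SimpleAverageForCausalLM,
    SimpleAverageForLlama,
)

# The old name is a plain alias, so existing imports and configs keep working.
assert SimpleAverageForLlama is SimpleAverageForCausalLM
```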
fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py}
CHANGED
@@ -1,22 +1,27 @@
 import logging
+import os
 from typing import Dict, List, Mapping, Optional, TypeVar, Union  # noqa: F401
 
 from typing_extensions import override
 
-from fusion_bench import timeit_context
+from fusion_bench import auto_register_config, timeit_context
 from fusion_bench.method import TaskArithmeticAlgorithm
 from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
 from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
+from fusion_bench.models.hf_utils import create_default_model_card
 
 log = logging.getLogger(__name__)
 
 
-class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
+@auto_register_config
+class TaskArithmeticForCausalLM(
+    TaskArithmeticAlgorithm,
+):
     R"""
     Examples:
 
     fusion_bench \
-        method=linear/task_arithmetic_for_llama \
+        method=linear/task_arithmetic_for_causallm \
         method.scaling_factor=0.3 \
         method.model_save_path=outputs/simle_mixtral_exp_v4/task_arithmetic_0.3 \
         modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
@@ -29,18 +34,14 @@ class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
     def __init__(
         self,
         scaling_factor: float,
-        merge_backbone: bool,
+        merge_backbone: bool = False,
         model_save_path: Optional[str] = None,
+        **kwargs,
     ):
-        self.merge_backbone = merge_backbone
-        self.model_save_path = model_save_path
-        super().__init__(scaling_factor=scaling_factor)
+        super().__init__(scaling_factor=scaling_factor, **kwargs)
 
     @override
     def run(self, modelpool: CausalLMPool):
-        if self.model_save_path:
-            tokenizer = modelpool.load_tokenizer()
-
         if self.merge_backbone:
             assert modelpool.has_pretrained
             backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
@@ -52,6 +53,15 @@ class TaskArithmeticForLlama(TaskArithmeticAlgorithm, SimpleProfilerMixin):
 
         if self.model_save_path is not None:
             with timeit_context(f"Saving the model to {self.model_save_path}"):
-                model.save_pretrained(self.model_save_path)
-                tokenizer.save_pretrained(self.model_save_path)
+                description = f"Merged model using task arithmetic with scaling factor {self.scaling_factor}."
+                modelpool.save_model(
+                    model=model,
+                    path=self.model_save_path,
+                    save_tokenizer=True,
+                    algorithm_config=self.config,
+                    description=description,
+                )
         return model
+
+
+TaskArithmeticForLlama = TaskArithmeticForCausalLM
fusion_bench/method/linear/ties_merging_for_causallm.py
ADDED
@@ -0,0 +1,70 @@
+import logging
+import os
+from typing import Dict, List, Mapping, Optional, TypeVar, Union  # noqa: F401
+
+from typing_extensions import override
+
+from fusion_bench import auto_register_config, timeit_context
+from fusion_bench.method import TiesMergingAlgorithm
+from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
+from fusion_bench.modelpool import CausalLMBackbonePool, CausalLMPool
+from fusion_bench.models.hf_utils import create_default_model_card
+
+log = logging.getLogger(__name__)
+
+
+@auto_register_config
+class TiesMergingForCausalLM(
+    TiesMergingAlgorithm,
+):
+    R"""
+    TIES merging algorithm for CausalLM models.
+
+    This class extends the TiesMergingAlgorithm to work specifically with CausalLM models,
+    providing model saving capabilities and backbone merging support.
+    """
+
+    _config_mapping = TiesMergingAlgorithm._config_mapping | {
+        "merge_backbone": "merge_backbone",
+    }
+
+    def __init__(
+        self,
+        scaling_factor: float,
+        threshold: float,
+        remove_keys: List[str] = None,
+        merge_func: str = "sum",
+        merge_backbone: bool = False,
+        model_save_path: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            scaling_factor=scaling_factor,
+            threshold=threshold,
+            remove_keys=remove_keys,
+            merge_func=merge_func,
+            **kwargs,
+        )
+
+    @override
+    def run(self, modelpool: CausalLMPool):
+        if self.merge_backbone:
+            assert modelpool.has_pretrained
+            backbone_modelpool = CausalLMBackbonePool(**modelpool.config)
+            model = modelpool.load_model("_pretrained_")
+            backbone_model = super().run(backbone_modelpool)
+            model.model.layers = backbone_model
+        else:
+            model = super().run(modelpool)
+
+        if self.model_save_path is not None:
+            with timeit_context(f"Saving the model to {self.model_save_path}"):
+                description = f"Merged model using TIES merging with scaling factor {self.scaling_factor} and threshold {self.threshold}."
+                modelpool.save_model(
+                    model=model,
+                    path=self.model_save_path,
+                    save_tokenizer=True,
+                    algorithm_config=self.config,
+                    description=description,
+                )
+        return model
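`TiesMergingForCausalLM` ships with a matching Hydra config (`fusion_bench_config/method/linear/ties_merging_for_causallm.yaml`, listed above). A sketch of direct, programmatic use based on the constructor in this diff; the argument values are illustrative only:

```python
from fusion_bench.method.linear import TiesMergingForCausalLM

algorithm = TiesMergingForCausalLM(
    scaling_factor=0.5,    # scaling of the merged task vector (inherited from TiesMergingAlgorithm)
    threshold=20,          # TIES trimming threshold
    remove_keys=None,
    merge_func="sum",
    merge_backbone=False,  # True merges only model.model.layers and keeps the rest from the pretrained model
    model_save_path=None,  # set a path to save the merged model, tokenizer, and model card
)
# merged_model = algorithm.run(modelpool)  # modelpool: a CausalLMPool holding the models to merge
```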
fusion_bench/method/model_stock/__init__.py
ADDED
@@ -0,0 +1 @@
+from .model_stock import ModelStock
fusion_bench/method/model_stock/model_stock.py
ADDED
@@ -0,0 +1,309 @@
+import copy
+import logging
+import math
+import os
+from collections import OrderedDict
+from copy import deepcopy
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from omegaconf import DictConfig
+from torch import nn
+from transformers import PreTrainedModel
+
+import fusion_bench
+from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.models import create_default_model_card
+from fusion_bench.utils.type import StateDictType
+
+log = logging.getLogger(__name__)
+
+EPS = 1e-8
+
+
+def compute_angle(
+    state_dict_1: StateDictType,
+    state_dict_2: StateDictType,
+    ref_state_dict: StateDictType,
+    ignore_keys: List[str] = [],
+    return_cos: bool = False,
+) -> Dict[str, float]:
+    """
+    Compute the angle between two state dictionaries relative to a reference state dictionary.
+
+    Args:
+        state_dict_1: First state dictionary
+        state_dict_2: Second state dictionary
+        ref_state_dict: Reference state dictionary (typically pre-trained model)
+        ignore_keys: Keys to ignore during computation
+        return_cos: If True, return cosine values instead of angles in degrees
+
+    Returns:
+        Dictionary mapping parameter names to angles (in degrees) or cosine values
+    """
+    # Remove the keys not used for CLIP fine-tuning (from the notebook example)
+
+    return_dict = OrderedDict()
+
+    with torch.no_grad():
+        for key in ref_state_dict:
+            if key in ignore_keys:
+                log.info(f"Ignoring key '{key}'")
+                continue
+
+            state_dict_1_val = state_dict_1[key]
+            state_dict_2_val = state_dict_2[key]
+            ref_val = ref_state_dict[key]
+
+            if not (state_dict_1_val.shape == state_dict_2_val.shape == ref_val.shape):
+                log.warning(
+                    f"Shape mismatch for key '{key}', ignored during merging: "
+                    f"({state_dict_1_val.shape}, {state_dict_2_val.shape}, {ref_val.shape})"
+                )
+                continue
+
+            vector1 = (state_dict_1_val - ref_val).clone().detach()
+            vector2 = (state_dict_2_val - ref_val).clone().detach()
+
+            vector1 = vector1.float()
+            vector2 = vector2.float()
+
+            cosine_val = torch.sum(vector1 * vector2) / (
+                math.sqrt(torch.sum(vector1**2) * torch.sum(vector2**2)) + EPS
+            )
+            cosine_val = torch.clamp(
+                cosine_val, min=-1.0, max=1.0
+            )  # Prevent nan from acos
+
+            if return_cos:
+                return_dict[key] = cosine_val.item()
+            else:
+                return_dict[key] = np.rad2deg(
+                    torch.acos(cosine_val).detach().cpu().item()
+                )
+
+    return return_dict
+
+
+def compute_ratio(angle_dict: Dict[str, float], k: int = 2) -> Dict[str, float]:
+    """
+    Compute interpolation ratios based on angles between fine-tuned models.
+
+    Args:
+        angle_dict: Dictionary mapping parameter names to angles in degrees
+        k: Number of fine-tuned models (default: 2)
+
+    Returns:
+        Dictionary mapping parameter names to interpolation ratios
+    """
+    ratio_dict = {}
+    for key in angle_dict.keys():
+        angle = np.deg2rad(angle_dict[key])
+        ratio_dict[key] = k * np.cos(angle) / ((k - 1) * np.cos(angle) + 1 + EPS)
+    return ratio_dict
+
+
+def merge_weights(
+    w1: StateDictType, w2: StateDictType, w0: StateDictType, ratio: Dict[str, float]
+) -> StateDictType:
+    """
+    Merge model weights using ModelStock formula.
+
+    Args:
+        w1: First fine-tuned model weights
+        w2: Second fine-tuned model weights
+        w0: Pre-trained model weights
+        ratio: Interpolation ratios for each parameter
+
+    Returns:
+        Merged model weights
+    """
+    # Compute w12 = (w1 + w2) / 2
+    w12 = {}
+    for key in w1.keys():
+        w12[key] = (w1[key].clone() + w2[key].clone()) / 2.0
+
+    # Apply ModelStock formula: w_merge = t * w12 + (1-t) * w0
+    w_merge = copy.deepcopy(w12)
+    for key, r in ratio.items():
+        w_merge[key] = w12[key].clone() * r + w0[key].clone() * (1.0 - r)
+
+    return w_merge
+
+
+@fusion_bench.auto_register_config
+class ModelStock(SimpleProfilerMixin, BaseAlgorithm):
+    """
+    Model Stock: All we need is just a few fine-tuned models
+
+    This method merges fine-tuned models by interpolating between their average
+    and a pre-trained anchor model, with interpolation ratios determined by
+    the angle between fine-tuned models in parameter space.
+    """
+
+    def __init__(
+        self,
+        ignore_keys: Optional[List[str]] = None,
+        model_save_path: Optional[str] = None,
+        model_save_kwargs: Optional[DictConfig] = None,
+        **kwargs,
+    ):
+        """
+        Initialize ModelStock algorithm.
+
+        Args:
+            ignore_keys: Additional parameter keys to ignore during merging
+        """
+        super().__init__(**kwargs)
+        if self.ignore_keys is None:
+            self.ignore_keys = []
+        if self.model_save_kwargs is None:
+            self.model_save_kwargs = DictConfig({})
+
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
+        """
+        Run the ModelStock merging algorithm.
+
+        Args:
+            modelpool: Pool of models containing pre-trained and fine-tuned models
+
+        Returns:
+            Merged model
+        """
+        with self.profile("model loading"):
+            # Load the pre-trained model (anchor)
+            pretrained_model = modelpool.load_pretrained_model()
+            if isinstance(pretrained_model, fusion_bench.LazyStateDict):
+                assert (
+                    pretrained_model.meta_module is not None
+                ), "Meta module is not initialized"
+            pretrained_state_dict = pretrained_model.state_dict()
+
+            # Load all fine-tuned models
+            finetuned_models = []
+            finetuned_state_dicts = []
+
+            for model_name in modelpool.model_names:
+                model = modelpool.load_model(model_name)
+                finetuned_models.append(model)
+                finetuned_state_dicts.append(model.state_dict())
+                log.info(f"Loaded fine-tuned model: {model_name}")
+
+        if len(finetuned_models) < 2:
+            raise ValueError("ModelStock requires at least 2 fine-tuned models")
+
+        log.info(f"Running ModelStock with {len(finetuned_models)} fine-tuned models")
+
+        with self.profile("compute angles and ratios"):
+            if len(finetuned_models) == 2:
+                # Two fine-tuned models case
+                angle_dict = compute_angle(
+                    finetuned_state_dicts[0],
+                    finetuned_state_dicts[1],
+                    pretrained_state_dict,
+                    ignore_keys=self.ignore_keys,
+                )
+                ratio_dict = compute_ratio(angle_dict, k=2)
+
+                log.info(f"Computed angles for {len(angle_dict)} parameter groups")
+
+            else:
+                # N fine-tuned models case - compute average angle
+                angles_sum = {}
+                angles_count = {}
+
+                # Compute pairwise angles and average them
+                for i in range(len(finetuned_models)):
+                    for j in range(i + 1, len(finetuned_models)):
+                        angle_dict = compute_angle(
+                            finetuned_state_dicts[i],
+                            finetuned_state_dicts[j],
+                            pretrained_state_dict,
+                            ignore_keys=self.ignore_keys,
+                        )
+
+                        for key, angle in angle_dict.items():
+                            if key not in angles_sum:
+                                angles_sum[key] = 0
+                                angles_count[key] = 0
+                            angles_sum[key] += angle
+                            angles_count[key] += 1
+
+                # Average the angles
+                avg_angle_dict = {}
+                for key in angles_sum:
+                    avg_angle_dict[key] = angles_sum[key] / angles_count[key]
+
+                ratio_dict = compute_ratio(avg_angle_dict, k=len(finetuned_models))
+
+                log.info(
+                    f"Computed average angles for {len(avg_angle_dict)} parameter groups"
+                )
+
+        with self.profile("merge weights"):
+            if len(finetuned_models) == 2:
+                # Direct merging for two models
+                merged_state_dict = merge_weights(
+                    finetuned_state_dicts[0],
+                    finetuned_state_dicts[1],
+                    pretrained_state_dict,
+                    ratio_dict,
+                )
+            else:
+                # For N models, first compute the average of fine-tuned models
+                avg_finetuned_state_dict = {}
+                for key in finetuned_state_dicts[0].keys():
+                    avg_finetuned_state_dict[key] = torch.zeros_like(
+                        finetuned_state_dicts[0][key]
+                    )
+                    for state_dict in finetuned_state_dicts:
+                        avg_finetuned_state_dict[key] += state_dict[key]
+                    avg_finetuned_state_dict[key] /= len(finetuned_state_dicts)
+
+                # Apply ModelStock formula: w_H = t * w_avg + (1-t) * w_0
+                merged_state_dict = copy.deepcopy(avg_finetuned_state_dict)
+                for key, r in ratio_dict.items():
+                    merged_state_dict[key] = avg_finetuned_state_dict[
+                        key
+                    ].clone() * r + pretrained_state_dict[key].clone() * (1.0 - r)
+
+        # Load merged weights into the model
+        if isinstance(pretrained_model, nn.Module):
+            result_model = pretrained_model
+        elif isinstance(pretrained_model, fusion_bench.LazyStateDict):
+            result_model = deepcopy(pretrained_model.meta_module)
+            result_model.to(device=pretrained_model._device)
+        result = result_model.load_state_dict(merged_state_dict, strict=False)
+
+        if result.unexpected_keys:
+            raise RuntimeError(
+                f"Unexpected keys in state dict: {result.unexpected_keys}"
+            )
+        if result.missing_keys:
+            log.warning(f"Missing keys in state dict: {result.missing_keys}")
+
+        if self.model_save_path is not None:
+            with self.profile("model saving"):
+                modelpool.save_model(
+                    model, path=self.model_save_path, **self.model_save_kwargs
+                )
+                if isinstance(model, PreTrainedModel):
+                    modelcard = create_default_model_card(
+                        models=[
+                            modelpool.get_model_path(m)
+                            for m in modelpool.all_model_names
+                        ],
+                        description="Merged model using [Model Stock](https://arxiv.org/abs/2403.19522).",
+                        algorithm_config=self.config,
+                        modelpool_config=modelpool.config,
+                    )
+                    with open(
+                        os.path.join(self.model_save_path, "README.md"), "w"
+                    ) as f:
+                        f.write(modelcard)
+
+        self.print_profile_summary()
+        log.info("ModelStock merging completed successfully")
+        return result_model
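The interpolation ratio implemented in `compute_ratio` follows the Model Stock paper (arXiv:2403.19522): for an angle θ between the two fine-tuned task vectors and k fine-tuned models, t = k·cos(θ) / ((k−1)·cos(θ) + 1). A small, illustrative sanity check of that behaviour (not part of the package tests), importing from the new module above:

```python
from fusion_bench.method.model_stock.model_stock import compute_ratio

# Nearly parallel fine-tuned updates (theta ~ 0 degrees) give t close to 1:
# the merge stays near the average of the fine-tuned weights.
print(compute_ratio({"layer.weight": 5.0}, k=2))   # ~0.998

# Nearly orthogonal updates (theta ~ 90 degrees) give t close to 0:
# the merge falls back towards the pre-trained anchor.
print(compute_ratio({"layer.weight": 89.0}, k=2))  # ~0.034
```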
fusion_bench/method/regmean/clip_regmean.py
CHANGED
@@ -9,6 +9,7 @@ from torch.nn.modules import Module
 from torch.utils.data import DataLoader
 from tqdm.autonotebook import tqdm
 
+from fusion_bench import auto_register_config
 from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import CLIPClassificationMixin
 
@@ -17,17 +18,13 @@ from .regmean import RegMeanAlgorithm
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class RegMeanAlgorithmForCLIP(
-    RegMeanAlgorithm,
     CLIPClassificationMixin,
+    RegMeanAlgorithm,
 ):
-    _config_mapping = {
-        "_dataloader_kwargs": "dataloader_kwargs",
-    }
-
     def __init__(self, *, dataloader_kwargs: DictConfig, **kwargs):
         super().__init__(**kwargs)
-        self.dataloader_kwargs = dataloader_kwargs
 
     def on_regmean_start(self):
         self.setup_zero_shot_classification_head()