fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -1
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +16 -6
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
- fusion_bench/method/simple_average.py +16 -4
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +6 -6
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/misc.py +48 -2
- fusion_bench/utils/modelscope.py +265 -0
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +34 -27
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/method/simple_average.py

@@ -6,7 +6,7 @@ import torch
 from torch import nn
 
 from fusion_bench.method.base_algorithm import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils import LazyStateDict
 from fusion_bench.utils.state_dict_arithmetic import (
@@ -59,12 +59,20 @@ def simple_average(
     return state_dict_avg(modules)
 
 
+@auto_register_config
 class SimpleAverageAlgorithm(
     BaseAlgorithm,
     SimpleProfilerMixin,
 ):
+    def __init__(self, show_pbar: bool = False, **kwargs):
+        """
+        Args:
+            show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
+        """
+        super().__init__(**kwargs)
+
     @torch.no_grad()
-    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
+    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]) -> nn.Module:
         """
         Fuse the models in the given model pool using simple averaging.
 
@@ -100,10 +108,14 @@ class SimpleAverageAlgorithm(
                     forward_model = model
                 else:
                     # Add the current model's state dictionary to the accumulated state dictionary
-                    sd = state_dict_add(sd, model.state_dict(keep_vars=True))
+                    sd = state_dict_add(
+                        sd, model.state_dict(keep_vars=True), show_pbar=self.show_pbar
+                    )
         with self.profile("merge weights"):
             # Divide the accumulated state dictionary by the number of models to get the average
-            sd = state_dict_div(sd, len(modelpool.model_names))
+            sd = state_dict_div(
+                sd, len(modelpool.model_names), show_pbar=self.show_pbar
+            )
 
         if isinstance(forward_model, LazyStateDict):
             # if the model is a LazyStateDict, convert it to an empty module
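Across this release, hand-written `_config_mapping` tables are replaced by the `auto_register_config` decorator, which records constructor arguments (such as the new `show_pbar` flag above) as attributes on the instance. A minimal sketch of calling the updated class, assuming two architecture-identical models; the `nn.Linear` placeholders are hypothetical:

```python
import torch.nn as nn

from fusion_bench.method.simple_average import SimpleAverageAlgorithm

# Hypothetical placeholder models; in practice these are fine-tuned
# checkpoints that share a single architecture.
model_a = nn.Linear(4, 2)
model_b = nn.Linear(4, 2)

# `show_pbar` is stored as `self.show_pbar` by @auto_register_config.
algorithm = SimpleAverageAlgorithm(show_pbar=True)
merged = algorithm.run({"model_a": model_a, "model_b": model_b})
```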
fusion_bench/method/slerp/slerp.py

@@ -1,10 +1,13 @@
 import logging
+from typing import Any, Dict
 
 import torch
+from torch import nn
 from typing_extensions import override
 
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils.type import StateDictType
 
 from .slerp_utils import slerp
 
@@ -18,7 +21,7 @@ def slerp_on_state_dicts(
     *,
     DOT_THRESHOLD: float = 0.9995,
     epsilon: float = 1e-8,
-):
+) -> StateDictType:
     """
     Perform spherical linear interpolation (slerp) on the state dictionaries of two models.
 
@@ -72,7 +75,7 @@ class SlerpMergeAlgorithm(BaseAlgorithm):
         super().__init__()
 
     @override
-    def run(self, modelpool: BaseModelPool):
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
         """
         Run the SlerpMergeAlgorithm on the given model pool.
 
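For readers unfamiliar with SLERP, the merged weights move along the arc between the two weight vectors rather than the straight line, falling back to linear interpolation when the vectors are nearly parallel (the role of `DOT_THRESHOLD` above). A self-contained sketch of the standard formula; this is illustrative, not fusion_bench's own `slerp_utils.slerp`:

```python
import torch

def slerp_sketch(w0: torch.Tensor, w1: torch.Tensor, t: float,
                 dot_threshold: float = 0.9995, eps: float = 1e-8) -> torch.Tensor:
    """Spherical linear interpolation between two weight tensors at fraction t."""
    v0, v1 = w0.flatten(), w1.flatten()
    # Cosine of the angle between the two weight vectors.
    dot = torch.dot(v0, v1) / (v0.norm() * v1.norm() + eps)
    if dot.abs() > dot_threshold:
        # Nearly colinear: plain linear interpolation is numerically safer.
        return (1 - t) * w0 + t * w1
    theta = torch.acos(dot.clamp(-1.0, 1.0))
    sin_theta = torch.sin(theta)
    s0 = torch.sin((1 - t) * theta) / sin_theta
    s1 = torch.sin(t * theta) / sin_theta
    return (s0 * v0 + s1 * v1).reshape(w0.shape)
```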
fusion_bench/method/smile_upscaling/error_accumulation.py (new file)

@@ -0,0 +1,177 @@
+import os
+from typing import Literal, cast
+
+import pandas as pd
+import torch
+from omegaconf import DictConfig
+from torch import nn
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import CLIPVisionModel
+
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+from fusion_bench.dataset import CLIPDataset
+from fusion_bench.method import SmileUpscalingAlgorithm
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
+from fusion_bench.modelpool import CLIPVisionModelPool
+from fusion_bench.taskpool.clip_vision.taskpool import LayerWiseFeatureSaver
+from fusion_bench.utils.devices import clear_cuda_cache
+
+
+@auto_register_config
+class LowRankApproximation(BaseAlgorithm):
+    def __init__(self, rank: int, device: str = "cuda", **kwargs):
+        """Low-rank approximation of fine-tuned updates."""
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: BaseModelPool):
+        # Implement low-rank approximation logic here
+        base_model = modelpool.load_pretrained_model()
+
+        models = {}
+        for model_name in tqdm(modelpool.model_names, "processing models"):
+            task_model = modelpool.load_model(model_name)
+            for module_name, module in task_model.named_modules():
+                if isinstance(module, nn.Linear):
+                    w = cast(
+                        nn.Linear, base_model.get_submodule(module_name)
+                    ).weight.to(dtype=torch.float32, device=self.device, copy=True)
+                    w_ft = module.weight.to(
+                        dtype=torch.float32, device=self.device, copy=True
+                    )
+
+                    # Compute low-rank approximation
+                    w_diff = w_ft - w
+                    u, s, vh = torch.linalg.svd(w_diff)
+                    v = vh.T
+
+                    u = u[:, : self.rank]
+                    s = s[: self.rank]
+                    v = v[:, : self.rank]
+
+                    low_rank_w_diff = torch.linalg.multi_dot((u, torch.diag(s), v.T))
+                    low_rank_w = w + low_rank_w_diff
+
+                    module.weight.data = low_rank_w.to(
+                        dtype=module.weight.dtype,
+                        device=module.weight.device,
+                    )
+
+            models[model_name] = task_model
+        return models
+
+
+@auto_register_config
+class ErrorAccumulationAnalysisForCLIP(
+    LightningFabricMixin,
+    BaseAlgorithm,
+):
+    def __init__(
+        self,
+        gate_k: int,
+        k: int,
+        seed: int = 42,
+        top_k: int = 1,
+        dataset_kwargs: DictConfig = None,
+        max_samples: int = 1024,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if dataset_kwargs is None:
+            self.dataset_kwargs = DictConfig(
+                {
+                    "batch_size": 32,
+                    "num_workers": 4,
+                }
+            )
+
+    def run(self, modelpool: CLIPVisionModelPool):
+        assert self.fabric.world_size == 1, "Distributed inference is not supported."
+        # get the smile model
+        smile_algorithm = SmileUpscalingAlgorithm(
+            gate_k=self.gate_k, k=self.k, top_k=self.top_k, device=self.fabric.device
+        )
+        smile_model = smile_algorithm.run(modelpool)
+        # get low-rank models
+        low_rank_models = LowRankApproximation(rank=self.k).run(modelpool)
+
+        results = {
+            "model_name": [],
+            "method": [],
+            "layer_index": [],
+            "approximation_error": [],
+        }
+
+        for model_name in modelpool.model_names:
+            dataset = modelpool.load_test_dataset(model_name)
+            processor = modelpool.load_processor()
+            dataset = CLIPDataset(dataset, processor)
+            dataloader = DataLoader(dataset, shuffle=True, **self.dataset_kwargs)
+            dataloader = self.fabric.setup_dataloaders(dataloader)
+
+            # finetuned_model
+            finetuned_model = modelpool.load_model(model_name)
+            finetuned_model = self.to_device(finetuned_model)
+            self.collect_hidden_states(
+                finetuned_model,
+                dataloader=dataloader,
+                model_name=f"{model_name}/finetuned",
+            )
+            del finetuned_model
+            clear_cuda_cache()
+
+            # smile model
+            smile_model = self.to_device(smile_model)
+            self.collect_hidden_states(
+                smile_model, dataloader=dataloader, model_name=f"{model_name}/smile"
+            )
+            smile_model.cpu()
+            clear_cuda_cache()
+
+            # low-rank models
+            model = low_rank_models.pop(model_name)
+            model = self.to_device(model)
+            self.collect_hidden_states(
+                model, dataloader=dataloader, model_name=f"{model_name}/low-rank"
+            )
+            del model
+            clear_cuda_cache()
+
+            del dataloader
+            clear_cuda_cache()
+
+    @torch.no_grad()
+    def collect_hidden_states(
+        self, model: CLIPVisionModel, dataloader, model_name: str
+    ):
+        self.fabric.seed_everything(
+            self.seed, workers=True
+        )  # make sure to get same data samples
+        # register hooks
+        hooks = {}
+        hook_handles = {}
+        for i, layer in enumerate(model.vision_model.encoder.layers):
+            hooks[i] = LayerWiseFeatureSaver(
+                save_path=os.path.join(self.log_dir, model_name, f"layer_{i}.pth"),
+                first_token_only=True,
+            )
+            hook_handles[i] = layer.register_forward_hook(hooks[i])
+
+        # forward pass
+        num_total_samples = 0
+        for images, _ in tqdm(dataloader, desc=f"Collecting features for {model_name}"):
+            batch_size = images.size(0)
+            model(images)
+            num_total_samples += batch_size
+            if num_total_samples >= self.max_samples:
+                break
+
+        # save features
+        for i, hook in hooks.items():
+            hook.save_features()
+
+        # remove hooks
+        for i, hook_handle in hook_handles.items():
+            hook_handle.remove()
+
+        return hooks
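`LowRankApproximation` above builds a per-task baseline by keeping only the top-`rank` singular directions of each linear update `w_ft - w`. By the Eckart-Young theorem that truncation is the best rank-k approximation of the update in Frobenius norm, which is what makes it a fair comparison point for SMILE's low-rank experts. A small self-contained check of how the truncation error shrinks with rank (random stand-in matrices, not fusion_bench code):

```python
import torch

torch.manual_seed(0)
w = torch.randn(64, 64)               # stand-in for a pretrained weight
w_ft = w + 0.1 * torch.randn(64, 64)  # stand-in for its fine-tuned version

w_diff = w_ft - w
u, s, vh = torch.linalg.svd(w_diff)

for rank in (1, 8, 32, 64):
    # Rank-k reconstruction of the update, as in LowRankApproximation.run
    approx = u[:, :rank] @ torch.diag(s[:rank]) @ vh[:rank, :]
    err = torch.linalg.norm(w_diff - approx) / torch.linalg.norm(w_diff)
    print(f"rank={rank:2d}  relative error={err:.3f}")  # shrinks to 0 at full rank
```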
fusion_bench/method/smile_upscaling/projected_energy.py (new file)

@@ -0,0 +1,145 @@
+import os
+from typing import Literal
+
+import pandas as pd
+import torch
+
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
+
+from tqdm import tqdm
+
+
+class ProjectedEnergyAnalysis(
+    SimpleProfilerMixin,
+    LightningFabricMixin,
+    BaseAlgorithm,
+):
+    def on_run_start(self):
+        self.device = self.fabric.device
+
+    def run(self, modelpool: BaseModelPool):
+        with self.profile("model loading"):
+            base_model = modelpool.load_pretrained_model()
+
+        results = {
+            "model_name": [],
+            "module_index": [],
+            "module_name": [],
+            "projected_energy_I": [],
+            "projected_energy_II": [],
+            "projected_energy_II_III": [],
+        }
+        for model_name in tqdm(
+            modelpool.model_names,
+            "analyzing",
+            dynamic_ncols=True,
+        ):
+            with self.profile("model loading"):
+                finetuned_model = modelpool.load_model(model_name)
+
+            module_index = 0
+            for module_name, base_module in tqdm(
+                list(base_model.named_modules()),
+                "analyzing modules",
+                dynamic_ncols=True,
+            ):
+                if isinstance(base_module, torch.nn.Linear):
+                    with self.profile("weight analysis"):
+                        _result = self.analyze_weight(
+                            base_module.weight,
+                            finetuned_model.get_submodule(module_name).weight,
+                        )
+                    results["model_name"].append(model_name)
+                    results["module_index"].append(module_index)
+                    results["module_name"].append(module_name)
+                    for key, value in _result.items():
+                        results[key].append(value)
+
+                    module_index += 1
+
+        # save results as csv
+        results = pd.DataFrame(results)
+        results.to_csv(
+            os.path.join(self.log_dir, "projected_energy_analysis.csv"), index=True
+        )
+
+        self.print_profile_summary()
+        return None
+
+    @torch.no_grad()
+    def analyze_weight(self, w: torch.Tensor, w_ft: torch.Tensor, k: int = -1):
+        w = w.to(dtype=torch.float32, device=self.device)
+        w_ft = w_ft.to(dtype=torch.float32, device=self.device)
+        w_diff = w_ft - w
+
+        # Perform analysis on the weight tensor
+        u, s, vh = torch.linalg.svd(w, full_matrices=False)
+        v = vh.T
+        if k < 0:
+            # find the position where the sum of singular values is larger than 50% of the total sum
+            cumsum = s.cumsum(0)
+            k = (cumsum < cumsum[-1] * 0.5).sum().item() + 1
+
+        # subspace I
+        w_diff_proj = self._project_subspace_low(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
+        projected_energy_I = (
+            torch.linalg.norm(w_diff_proj, ord="fro") ** 2
+            / torch.linalg.norm(w_diff, ord="fro") ** 2
+        )
+
+        # subspace II
+        w_diff_proj = self._project_subspace_high(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
+        projected_energy_II = (
+            torch.linalg.norm(w_diff_proj, ord="fro") ** 2
+            / torch.linalg.norm(w_diff, ord="fro") ** 2
+        )
+
+        ## subspace II+III
+        u, s, vh = torch.linalg.svd(w, full_matrices=True)
+        v = vh.T
+        w_diff_proj = self._project_subspace_high(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
+        projected_energy_II_III = (
+            torch.linalg.norm(w_diff_proj, ord="fro") ** 2
+            / torch.linalg.norm(w_diff, ord="fro") ** 2
+        )
+
+        return {
+            "projected_energy_I": projected_energy_I.item(),
+            "projected_energy_II": projected_energy_II.item(),
+            "projected_energy_II_III": projected_energy_II_III.item(),
+        }
+
+    def _project_subspace_low(
+        self,
+        u: torch.Tensor,
+        s: torch.Tensor,
+        v: torch.Tensor,
+        k: int,
+        w: torch.Tensor,
+        w_ft: torch.Tensor,
+    ):
+        u = u[:, :k]
+        s = s[:k]
+        v = v[:, :k]
+
+        w_diff = w_ft - w
+        w_diff_proj = torch.linalg.multi_dot((u, u.T, w_diff, v, v.T))
+        return w_diff_proj
+
+    def _project_subspace_high(
+        self,
+        u: torch.Tensor,
+        s: torch.Tensor,
+        v: torch.Tensor,
+        k: int,
+        w: torch.Tensor,
+        w_ft: torch.Tensor,
+    ):
+        u = u[:, k:]
+        s = s[k:]
+        v = v[:, k:]
+
+        w_diff = w_ft - w
+        w_diff_proj = torch.linalg.multi_dot((u, u.T, w_diff, v, v.T))
+        return w_diff_proj
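`ProjectedEnergyAnalysis.analyze_weight` above reports normalized Frobenius energies of the update Delta W = W_ft - W projected onto subspaces defined by the singular vectors of the pretrained W: subspace I spans the top-k directions, subspace II the remaining thin-SVD directions, and II+III the full orthogonal complement obtained with `full_matrices=True`. A minimal sketch of the subspace-I quantity on random matrices (k is fixed here; the code above derives it from a 50% cumulative-singular-value rule):

```python
import torch

torch.manual_seed(0)
w = torch.randn(32, 48)                # stand-in pretrained weight
w_ft = w + 0.05 * torch.randn(32, 48)  # stand-in fine-tuned weight
w_diff = w_ft - w

u, s, vh = torch.linalg.svd(w, full_matrices=False)
v = vh.T
k = 8  # hypothetical cut-off

# Project the update onto the span of the top-k singular directions of w.
u_k, v_k = u[:, :k], v[:, :k]
w_diff_proj = u_k @ u_k.T @ w_diff @ v_k @ v_k.T

energy_I = (torch.linalg.norm(w_diff_proj, ord="fro") ** 2
            / torch.linalg.norm(w_diff, ord="fro") ** 2)
print(f"projected_energy_I = {energy_I.item():.4f}")  # fraction of update energy in subspace I
```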
fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py

@@ -16,10 +16,16 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.compat.modelpool import to_modelpool
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+from fusion_bench.modelpool import CausalLMPool
+from fusion_bench.models.hf_utils import (
+    generate_complete_readme,
+    save_pretrained_with_remote_code,
+)
 from fusion_bench.models.modeling_smile_qwen2 import (
     SmileQwen2Config,
     SmileQwen2ForCausalLM,
+    SmileQwen2Model,
 )
 from fusion_bench.models.modeling_smile_qwen2.modeling_smile_qwen2 import (
     SmileQwen2DecoderLayer,
@@ -34,6 +40,7 @@ from fusion_bench.utils.parameters import print_parameters
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     R"""
     SmileQwen2UpscalingAlgorithm is a model fusion algorithm designed to upscale
@@ -49,15 +56,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     Merges the pretrained model with the fine-tuned models to create an upscaled model.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "device": "device",
-        "accelerator": "accelerator",
-        "model_path": "model_path",
-        "model_dtype": "model_dtype",
-        "num_experts_per_tok": "num_experts_per_tok",
-        "rank_of_router": "rank_of_router",
-        "rank_of_expert": "rank_of_expert",
-    }
+    modelpool: CausalLMPool
 
     def __init__(
         self,
@@ -68,20 +67,13 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         num_experts_per_tok,
         rank_of_router,
         rank_of_expert,
+        save_with_remote_code: bool = True,
         **kwargs,
     ):
-        self.device = device
-        self.accelerator = accelerator
-        self.model_path = model_path
-        self.model_dtype = model_dtype
-        # SmileMoE parameters, except `num_local_experts` which is set later according to the number of finetuned models
-        self.num_experts_per_tok = num_experts_per_tok
-        self.rank_of_router = rank_of_router
-        self.rank_of_expert = rank_of_expert
         super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool):
+    def run(self, modelpool) -> SmileQwen2ForCausalLM:
         """
         Executes the upscaling process.
 
@@ -129,13 +121,29 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         if os.path.dirname(config.model_path):
             os.makedirs(os.path.dirname(config.model_path), exist_ok=True)
         log.info(f"Saving model to {config.model_path}")
-
-        pretrained_path = pretrained_model_config.get(
-            "path", pretrained_model_config["pretrained_model_name_or_path"]
-        )
-        tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
+        tokenizer = self.modelpool.load_tokenizer()
         tokenizer.save_pretrained(config.model_path)
-        model.save_pretrained(config.model_path)
+        if not self.save_with_remote_code:
+            model.save_pretrained(config.model_path)
+        else:
+            save_pretrained_with_remote_code(
+                model,
+                auto_map={
+                    "AutoConfig": SmileQwen2Config,
+                    "AutoModel": SmileQwen2Model,
+                    "AutoModelForCausalLM": SmileQwen2ForCausalLM,
+                },
+                save_directory=config.model_path,
+            )
+
+        # save readme
+        complete_readme = generate_complete_readme(
+            algorithm=self,
+            modelpool=modelpool,
+            models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
+        )
+        with open(os.path.join(config.model_path, "README.md"), "w") as f:
+            f.write(complete_readme)
 
         return model
 
@@ -158,9 +166,12 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
 
         with init_empty_weights():
            pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
-            pretrained_path = pretrained_model_config.get(
-                "path", pretrained_model_config["pretrained_model_name_or_path"]
-            )
+            if isinstance(pretrained_model_config, str):
+                pretrained_path = pretrained_model_config
+            else:
+                pretrained_path = pretrained_model_config.get(
+                    "path", pretrained_model_config["pretrained_model_name_or_path"]
+                )
             base_config = AutoConfig.from_pretrained(pretrained_path)
             model_config = SmileQwen2Config(
                 num_experts_per_tok=config.num_experts_per_tok,
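The new `save_with_remote_code` path bundles the custom `SmileQwen2*` modeling code with the checkpoint through an `auto_map`, so the upscaled model can be reloaded without having fusion_bench importable. A sketch of the intended round trip; the output directory is a hypothetical placeholder, and `trust_remote_code=True` is the standard transformers flag for executing bundled modeling code:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical output directory written by SmileQwen2UpscalingAlgorithm.
model_path = "outputs/smile_qwen2_upscaled"

# Because the checkpoint ships its own modeling files via `auto_map`,
# loading it only needs transformers plus trust_remote_code.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
```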
fusion_bench/method/smile_upscaling/smile_upscaling.py

@@ -1,7 +1,7 @@
 import logging
 import os
 from copy import deepcopy
-from typing import Dict, List, Tuple  # noqa: F401
+from typing import Any, Dict, List, Tuple  # noqa: F401
 
 import torch
 import torch.nn.functional as F
@@ -21,6 +21,7 @@ from fusion_bench.models.smile_moe.linear_from_module import (
 )
 from fusion_bench.models.utils import get_attr, set_attr
 from fusion_bench.utils.parameters import print_parameters
+from fusion_bench.utils.devices import get_device
 
 log = logging.getLogger(__name__)
 
@@ -54,7 +55,7 @@ class SmileUpscalingAlgorithm(
         routing_use_diff: bool = True,
         average_experts: bool = False,
         model_path: str = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Initialize the SmileUpscalingAlgorithm.
@@ -91,7 +92,7 @@ class SmileUpscalingAlgorithm(
             print(f"=== Config for `{type(self).__name__}` ===")
 
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool):
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
         """
         Executes the upscaling process.
 
@@ -142,7 +143,7 @@ class SmileUpscalingAlgorithm(
         pretrained_model: nn.Module,
         finetuned_models: List[nn.Module],
         in_place: bool = True,
-    ):
+    ) -> nn.Module:
         """
         Merges the pretrained model with the fine-tuned models to create an upscaled model.
 
@@ -180,7 +181,12 @@ class SmileUpscalingAlgorithm(
 
         name_list = name.split(".")
         module = get_attr(pretrained_model, name_list)
-        experts = [get_attr(m, name_list) for m in finetuned_models]
+        original_device = get_device(module)
+        module = module.to(self.device, non_blocking=True)
+        experts = [
+            get_attr(m, name_list).to(self.device, non_blocking=True)
+            for m in finetuned_models
+        ]
         try:
             moe_linear = SmileMoELinear(
                 module,
@@ -192,6 +198,7 @@ class SmileUpscalingAlgorithm(
                 full_matrices=self.full_matrices,
                 upscaling_accelerator=self.upscaling_accelerator,
             )
+            moe_linear = moe_linear.to(original_device, non_blocking=True)
         except ExpertNotTrainedError:
             print(f"skip {name} because the experts are not trained.")
             return
fusion_bench/method/tall_mask/task_arithmetic.py

@@ -9,7 +9,7 @@ from copy import deepcopy
 import torch
 
 from fusion_bench import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
@@ -58,16 +58,11 @@ def generate_task_masks(
     return final_mask
 
 
+@auto_register_config
 class TallMaskTaskArithmeticAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "tall_mask_lambda": "tall_mask_lambda",
-        "debug": "debug",
-        "verbose": "verbose",
-    }
-
     def __init__(
         self,
         tall_mask_lambda: float,
@@ -76,9 +71,6 @@ class TallMaskTaskArithmeticAlgorithm(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.tall_mask_lambda = tall_mask_lambda
-        self.debug = debug
-        self.verbose = verbose
 
     @torch.no_grad()
     def run(self, modelpool: BaseModelPool):
fusion_bench/method/task_arithmetic/task_arithmetic.py

@@ -12,7 +12,7 @@ import torch
 from torch import nn
 
 from fusion_bench.method.base_algorithm import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
@@ -74,9 +74,10 @@ def task_arithmetic_merge(
     return pretrained_model
 
 
+@auto_register_config
 class TaskArithmeticAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
     """
     Task Arithmetic Algorithm for model fusion.
@@ -89,22 +90,17 @@ class TaskArithmeticAlgorithm(
         scaling_factor (int): The factor by which the task vectors will be scaled before merging.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "scaling_factor": "scaling_factor"
-    }
-
-    def __init__(self, scaling_factor: int):
+    def __init__(self, scaling_factor: int, **kwargs):
         """
         Initializes the TaskArithmeticAlgorithm with the given scaling factor.
 
         Args:
             scaling_factor (int): The factor by which the task vectors will be scaled before merging.
         """
-
-        super().__init__()
+        super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
+    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]) -> nn.Module:
         """
         Runs the Task Arithmetic Algorithm to fuse models in the given model pool.
 
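The refactor above does not change the arithmetic itself: the merged weights are still theta_merged = theta_pre + scaling_factor * sum_i (theta_i - theta_pre). A minimal sketch of that formula on plain state dicts (illustrative only, not the fusion_bench helpers):

```python
import torch

def task_arithmetic_sketch(pretrained, finetuned_list, scaling_factor):
    """theta_merged = theta_pre + scaling_factor * sum_i (theta_i - theta_pre)."""
    merged = {}
    for name, w_pre in pretrained.items():
        # Each task vector is the difference between a fine-tuned weight and the base.
        task_vector_sum = sum(ft[name] - w_pre for ft in finetuned_list)
        merged[name] = w_pre + scaling_factor * task_vector_sum
    return merged

pre = {"w": torch.zeros(2, 2)}
fts = [{"w": torch.ones(2, 2)}, {"w": 2 * torch.ones(2, 2)}]
print(task_arithmetic_sketch(pre, fts, scaling_factor=0.3)["w"])  # 0.3 * (1 + 2) = 0.9
```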