fusion-bench 0.2.18__py3-none-any.whl → 0.2.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/constants/banner.py +12 -0
- fusion_bench/method/__init__.py +2 -0
- fusion_bench/method/linear/simple_average_for_llama.py +30 -5
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +192 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +365 -0
- fusion_bench/method/simple_average.py +29 -3
- fusion_bench/modelpool/causal_lm/causal_lm.py +37 -6
- fusion_bench/modelpool/clip_vision/modelpool.py +45 -12
- fusion_bench/scripts/cli.py +1 -1
- fusion_bench/tasks/clip_classification/imagenet.py +1008 -2004
- fusion_bench/utils/lazy_state_dict.py +75 -3
- fusion_bench/utils/misc.py +66 -2
- fusion_bench/utils/modelscope.py +146 -0
- fusion_bench/utils/state_dict_arithmetic.py +10 -5
- {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/METADATA +9 -1
- {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/RECORD +50 -43
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +11 -0
- {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.18.dist-info → fusion_bench-0.2.20.dist-info}/top_level.txt +0 -0
|
@@ -8,6 +8,7 @@ from torch import nn
|
|
|
8
8
|
from fusion_bench.method.base_algorithm import BaseAlgorithm
|
|
9
9
|
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin
|
|
10
10
|
from fusion_bench.modelpool import BaseModelPool
|
|
11
|
+
from fusion_bench.utils import LazyStateDict
|
|
11
12
|
from fusion_bench.utils.state_dict_arithmetic import (
|
|
12
13
|
state_dict_add,
|
|
13
14
|
state_dict_avg,
|
|
@@ -62,6 +63,18 @@ class SimpleAverageAlgorithm(
|
|
|
62
63
|
BaseAlgorithm,
|
|
63
64
|
SimpleProfilerMixin,
|
|
64
65
|
):
|
|
66
|
+
_config_mapping = BaseAlgorithm._config_mapping | {
|
|
67
|
+
"show_pbar": "show_pbar",
|
|
68
|
+
}
|
|
69
|
+
|
|
70
|
+
def __init__(self, show_pbar: bool = False):
|
|
71
|
+
"""
|
|
72
|
+
Args:
|
|
73
|
+
show_pbar (bool): If True, shows a progress bar during model loading and merging. Default is False.
|
|
74
|
+
"""
|
|
75
|
+
super().__init__()
|
|
76
|
+
self.show_pbar = show_pbar
|
|
77
|
+
|
|
65
78
|
@torch.no_grad()
|
|
66
79
|
def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
|
|
67
80
|
"""
|
|
@@ -99,11 +112,24 @@ class SimpleAverageAlgorithm(
|
|
|
99
112
|
forward_model = model
|
|
100
113
|
else:
|
|
101
114
|
# Add the current model's state dictionary to the accumulated state dictionary
|
|
102
|
-
sd = state_dict_add(
|
|
115
|
+
sd = state_dict_add(
|
|
116
|
+
sd, model.state_dict(keep_vars=True), show_pbar=self.show_pbar
|
|
117
|
+
)
|
|
103
118
|
with self.profile("merge weights"):
|
|
104
119
|
# Divide the accumulated state dictionary by the number of models to get the average
|
|
105
|
-
sd = state_dict_div(
|
|
106
|
-
|
|
120
|
+
sd = state_dict_div(
|
|
121
|
+
sd, len(modelpool.model_names), show_pbar=self.show_pbar
|
|
122
|
+
)
|
|
123
|
+
|
|
124
|
+
if isinstance(forward_model, LazyStateDict):
|
|
125
|
+
# if the model is a LazyStateDict, convert it to an empty module
|
|
126
|
+
forward_model = forward_model.meta_module.to_empty(
|
|
127
|
+
device=(
|
|
128
|
+
"cpu"
|
|
129
|
+
if forward_model._torch_dtype is None
|
|
130
|
+
else forward_model._torch_dtype
|
|
131
|
+
)
|
|
132
|
+
)
|
|
107
133
|
forward_model.load_state_dict(sd)
|
|
108
134
|
# print profile report and log the merged models
|
|
109
135
|
self.print_profile_summary()
|
|
@@ -22,6 +22,8 @@ from typing_extensions import override
|
|
|
22
22
|
from fusion_bench.modelpool import BaseModelPool
|
|
23
23
|
from fusion_bench.utils import instantiate
|
|
24
24
|
from fusion_bench.utils.dtype import parse_dtype
|
|
25
|
+
from fusion_bench.utils.lazy_state_dict import LazyStateDict
|
|
26
|
+
from fusion_bench.utils.packages import import_object
|
|
25
27
|
|
|
26
28
|
log = logging.getLogger(__name__)
|
|
27
29
|
|
|
@@ -30,6 +32,7 @@ class CausalLMPool(BaseModelPool):
|
|
|
30
32
|
_config_mapping = BaseModelPool._config_mapping | {
|
|
31
33
|
"_tokenizer": "tokenizer",
|
|
32
34
|
"_model_kwargs": "model_kwargs",
|
|
35
|
+
"load_lazy": "load_lazy",
|
|
33
36
|
}
|
|
34
37
|
|
|
35
38
|
def __init__(
|
|
@@ -38,6 +41,7 @@ class CausalLMPool(BaseModelPool):
|
|
|
38
41
|
*,
|
|
39
42
|
tokenizer: Optional[DictConfig],
|
|
40
43
|
model_kwargs: Optional[DictConfig] = None,
|
|
44
|
+
load_lazy: bool = False,
|
|
41
45
|
**kwargs,
|
|
42
46
|
):
|
|
43
47
|
super().__init__(models, **kwargs)
|
|
@@ -51,6 +55,7 @@ class CausalLMPool(BaseModelPool):
|
|
|
51
55
|
self._model_kwargs.torch_dtype = parse_dtype(
|
|
52
56
|
self._model_kwargs.torch_dtype
|
|
53
57
|
)
|
|
58
|
+
self.load_lazy = load_lazy
|
|
54
59
|
|
|
55
60
|
@override
|
|
56
61
|
def load_model(
|
|
@@ -88,21 +93,41 @@ class CausalLMPool(BaseModelPool):
|
|
|
88
93
|
model_kwargs.update(kwargs)
|
|
89
94
|
|
|
90
95
|
if isinstance(model_name_or_config, str):
|
|
96
|
+
# If model_name_or_config is a string, it is the name or the path of the model
|
|
91
97
|
log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
|
|
92
98
|
if model_name_or_config in self._models.keys():
|
|
93
99
|
model_config = self._models[model_name_or_config]
|
|
94
100
|
if isinstance(model_config, str):
|
|
95
101
|
# model_config is a string
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
102
|
+
if not self.load_lazy:
|
|
103
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
104
|
+
model_config,
|
|
105
|
+
*args,
|
|
106
|
+
**model_kwargs,
|
|
107
|
+
)
|
|
108
|
+
else:
|
|
109
|
+
# model_config is a string, but we want to use LazyStateDict
|
|
110
|
+
model = LazyStateDict(
|
|
111
|
+
checkpoint=model_config,
|
|
112
|
+
meta_module_class=AutoModelForCausalLM,
|
|
113
|
+
*args,
|
|
114
|
+
**model_kwargs,
|
|
115
|
+
)
|
|
101
116
|
return model
|
|
102
117
|
elif isinstance(model_name_or_config, (DictConfig, Dict)):
|
|
103
118
|
model_config = model_name_or_config
|
|
104
119
|
|
|
105
|
-
|
|
120
|
+
if not self.load_lazy:
|
|
121
|
+
model = instantiate(model_config, *args, **model_kwargs)
|
|
122
|
+
else:
|
|
123
|
+
meta_module_class = model_config.pop("_target_")
|
|
124
|
+
checkpoint = model_config.pop("pretrained_model_name_or_path")
|
|
125
|
+
model = LazyStateDict(
|
|
126
|
+
checkpoint=checkpoint,
|
|
127
|
+
meta_module_class=meta_module_class,
|
|
128
|
+
*args,
|
|
129
|
+
**model_kwargs,
|
|
130
|
+
)
|
|
106
131
|
return model
|
|
107
132
|
|
|
108
133
|
def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
|
|
@@ -179,6 +204,12 @@ class CausalLMBackbonePool(CausalLMPool):
|
|
|
179
204
|
def load_model(
|
|
180
205
|
self, model_name_or_config: str | DictConfig, *args, **kwargs
|
|
181
206
|
) -> Module:
|
|
207
|
+
if self.load_lazy:
|
|
208
|
+
log.warning(
|
|
209
|
+
"CausalLMBackbonePool does not support lazy loading. "
|
|
210
|
+
"Falling back to normal loading."
|
|
211
|
+
)
|
|
212
|
+
self.load_lazy = False
|
|
182
213
|
model: AutoModelForCausalLM = super().load_model(
|
|
183
214
|
model_name_or_config, *args, **kwargs
|
|
184
215
|
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from copy import deepcopy
|
|
3
|
-
from typing import Optional, Union
|
|
3
|
+
from typing import Literal, Optional, Union
|
|
4
4
|
|
|
5
5
|
from datasets import load_dataset
|
|
6
6
|
from lightning.fabric.utilities import rank_zero_only
|
|
@@ -11,6 +11,9 @@ from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
|
|
|
11
11
|
from typing_extensions import override
|
|
12
12
|
|
|
13
13
|
from fusion_bench.utils import instantiate, timeit_context
|
|
14
|
+
from fusion_bench.utils.modelscope import (
|
|
15
|
+
resolve_repo_path,
|
|
16
|
+
)
|
|
14
17
|
|
|
15
18
|
from ..base_pool import BaseModelPool
|
|
16
19
|
|
|
@@ -25,25 +28,32 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
25
28
|
the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.
|
|
26
29
|
"""
|
|
27
30
|
|
|
28
|
-
_config_mapping = BaseModelPool._config_mapping | {
|
|
31
|
+
_config_mapping = BaseModelPool._config_mapping | {
|
|
32
|
+
"_processor": "processor",
|
|
33
|
+
"_platform": "hf",
|
|
34
|
+
}
|
|
29
35
|
|
|
30
36
|
def __init__(
|
|
31
37
|
self,
|
|
32
38
|
models: DictConfig,
|
|
33
39
|
*,
|
|
34
40
|
processor: Optional[DictConfig] = None,
|
|
41
|
+
platform: Literal["hf", "huggingface", "modelscope"] = "hf",
|
|
35
42
|
**kwargs,
|
|
36
43
|
):
|
|
37
44
|
super().__init__(models, **kwargs)
|
|
38
|
-
|
|
39
45
|
self._processor = processor
|
|
46
|
+
self._platform = platform
|
|
40
47
|
|
|
41
48
|
def load_processor(self, *args, **kwargs) -> CLIPProcessor:
|
|
42
49
|
assert self._processor is not None, "Processor is not defined in the config"
|
|
43
50
|
if isinstance(self._processor, str):
|
|
44
51
|
if rank_zero_only.rank == 0:
|
|
45
52
|
log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
|
|
46
|
-
|
|
53
|
+
repo_path = resolve_repo_path(
|
|
54
|
+
repo_id=self._processor, repo_type="model", platform=self._platform
|
|
55
|
+
)
|
|
56
|
+
processor = CLIPProcessor.from_pretrained(repo_path, *args, **kwargs)
|
|
47
57
|
else:
|
|
48
58
|
processor = instantiate(self._processor, *args, **kwargs)
|
|
49
59
|
return processor
|
|
@@ -54,7 +64,10 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
54
64
|
if isinstance(model_config, str):
|
|
55
65
|
if rank_zero_only.rank == 0:
|
|
56
66
|
log.info(f"Loading `transformers.CLIPModel`: {model_config}")
|
|
57
|
-
|
|
67
|
+
repo_path = resolve_repo_path(
|
|
68
|
+
repo_id=model_config, repo_type="model", platform=self._platform
|
|
69
|
+
)
|
|
70
|
+
clip_model = CLIPModel.from_pretrained(repo_path, *args, **kwargs)
|
|
58
71
|
return clip_model
|
|
59
72
|
else:
|
|
60
73
|
assert isinstance(
|
|
@@ -107,14 +120,17 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
107
120
|
if isinstance(model, str):
|
|
108
121
|
if rank_zero_only.rank == 0:
|
|
109
122
|
log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
|
|
110
|
-
|
|
123
|
+
repo_path = resolve_repo_path(
|
|
124
|
+
model, repo_type="model", platform=self._platform
|
|
125
|
+
)
|
|
126
|
+
return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
|
|
111
127
|
if isinstance(model, nn.Module):
|
|
112
128
|
if rank_zero_only.rank == 0:
|
|
113
129
|
log.info(f"Returning existing model: {model}")
|
|
114
130
|
return model
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
131
|
+
else:
|
|
132
|
+
# If the model is not a string, we use the default load_model method
|
|
133
|
+
return super().load_model(model_name_or_config, *args, **kwargs)
|
|
118
134
|
|
|
119
135
|
def load_train_dataset(self, dataset_name: str, *args, **kwargs):
|
|
120
136
|
dataset_config = self._train_datasets[dataset_name]
|
|
@@ -123,7 +139,7 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
123
139
|
log.info(
|
|
124
140
|
f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
|
|
125
141
|
)
|
|
126
|
-
dataset =
|
|
142
|
+
dataset = self._load_dataset(dataset_config, split="train")
|
|
127
143
|
else:
|
|
128
144
|
dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
|
|
129
145
|
return dataset
|
|
@@ -135,7 +151,7 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
135
151
|
log.info(
|
|
136
152
|
f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
|
|
137
153
|
)
|
|
138
|
-
dataset =
|
|
154
|
+
dataset = self._load_dataset(dataset_config, split="validation")
|
|
139
155
|
else:
|
|
140
156
|
dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
|
|
141
157
|
return dataset
|
|
@@ -147,7 +163,24 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
147
163
|
log.info(
|
|
148
164
|
f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
|
|
149
165
|
)
|
|
150
|
-
dataset =
|
|
166
|
+
dataset = self._load_dataset(dataset_config, split="test")
|
|
151
167
|
else:
|
|
152
168
|
dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
|
|
153
169
|
return dataset
|
|
170
|
+
|
|
171
|
+
def _load_dataset(self, name: str, split: str):
|
|
172
|
+
"""
|
|
173
|
+
Load a dataset by its name and split.
|
|
174
|
+
|
|
175
|
+
Args:
|
|
176
|
+
dataset_name (str): The name of the dataset.
|
|
177
|
+
split (str): The split of the dataset to load (e.g., "train", "validation", "test").
|
|
178
|
+
|
|
179
|
+
Returns:
|
|
180
|
+
Dataset: The loaded dataset.
|
|
181
|
+
"""
|
|
182
|
+
datset_dir = resolve_repo_path(
|
|
183
|
+
name, repo_type="dataset", platform=self._platform
|
|
184
|
+
)
|
|
185
|
+
dataset = load_dataset(datset_dir, split=split)
|
|
186
|
+
return dataset
|
fusion_bench/scripts/cli.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
"""
|
|
3
|
-
This is the CLI script that is executed when the user runs the `
|
|
3
|
+
This is the CLI script that is executed when the user runs the `fusion_bench` command.
|
|
4
4
|
The script is responsible for parsing the command-line arguments, loading the configuration file, and running the fusion algorithm.
|
|
5
5
|
"""
|
|
6
6
|
|