fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +25 -2
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/constants/__init__.py +1 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -4
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/__init__.py +1 -0
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
- fusion_bench/method/linear/simple_average_for_llama.py +16 -11
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +7 -7
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
- fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/weighted_average/llama.py +1 -1
- fusion_bench/mixins/clip_classification.py +37 -48
- fusion_bench/mixins/serialization.py +30 -10
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/__init__.py +5 -0
- fusion_bench/models/hf_utils.py +69 -86
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
- fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/programs/fabric_fusion_program.py +29 -60
- fusion_bench/scripts/cli.py +34 -1
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +2 -0
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +7 -4
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +117 -19
- fusion_bench/utils/modelscope.py +3 -3
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/mixins/clip_classification.py:

@@ -22,6 +22,7 @@ from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 
+from fusion_bench import cache_with_joblib
 from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.modelpool import CLIPVisionModelPool

@@ -46,7 +47,6 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
     - `_dataloader_kwargs` (Dict[str, Any]): Keyword arguments for the dataloader.
    - `modelpool` (CLIPVisionModelPool): The model pool containing the CLIP models.
-    - `zeroshot_weights_cache_dir` (Optional[str]): The directory to cache the zero-shot weights.
     """
 
     dataloader_kwargs: Dict[str, Any] = {}

@@ -54,7 +54,6 @@ class CLIPClassificationMixin(LightningFabricMixin):
     modelpool: CLIPVisionModelPool = None
     _clip_processor: CLIPProcessor = None
     # a dict of zeroshot weights for each task, each key is the task name
-    zeroshot_weights_cache_dir: str = "outputs/cache/clip_zeroshot_weights"
     zeroshot_weights: Dict[str, torch.Tensor] = {}
     whether_setup_zero_shot_classification_head = False
 

@@ -114,11 +113,27 @@ class CLIPClassificationMixin(LightningFabricMixin):
         clip_model: Optional[CLIPModel] = None,
         task_names: Optional[List[str]] = None,
     ):
+        """
+        Initializes a zero-shot classification head.
+
+        This method constructs a zero-shot classification head by generating text embeddings for each class name using a set of templates.
+        These embeddings function as the weights of the classification layer. The method also extracts the `visual_projection` and `logit_scale`
+        from the provided CLIP model, which are necessary for calculating the final logits.
+
+        Args:
+            clip_processor (Optional[CLIPProcessor]): The processor for the CLIP model. If not provided, it is loaded from the model pool.
+            clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
+            task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
+        """
         self.whether_setup_zero_shot_classification_head = True
+        # load clip model if not provided
         if clip_model is None:
             if self.modelpool.has_pretrained:
                 clip_model = self.modelpool.load_clip_model("_pretrained_")
             else:
+                log.warning(
+                    f"No pretrained CLIP model found, using the model from the model pool: {self.modelpool.model_names[0]}."
+                )
                 clip_model = self.modelpool.load_clip_model(
                     self.modelpool.model_names[0]
                 )
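The new docstring describes the standard CLIP zero-shot recipe: each class name is rendered through a set of prompt templates, encoded by the text tower, and the resulting embeddings become the rows of the classification head. Below is a minimal sketch of that recipe with Hugging Face CLIP; the checkpoint, class names, and templates are illustrative placeholders, and in fusion_bench this work is done by the `clip_classifier` object (`set_classification_task` / `zeroshot_weights` in the hunks that follow), not by this code.

```python
# Illustrative sketch, not fusion_bench internals.
import torch
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

classnames = ["cat", "dog"]
templates = ["a photo of a {}.", "a blurry photo of a {}."]

with torch.no_grad():
    rows = []
    for classname in classnames:
        # Encode every template filled with this class name.
        texts = [t.format(classname) for t in templates]
        inputs = processor(text=texts, return_tensors="pt", padding=True)
        embeds = model.get_text_features(**inputs)
        # L2-normalize, average over templates, and re-normalize.
        embeds = embeds / embeds.norm(dim=-1, keepdim=True)
        class_embed = embeds.mean(dim=0)
        rows.append(class_embed / class_embed.norm())
    zeroshot_weights = torch.stack(rows)  # (num_classes, projection_dim)
```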
@@ -131,26 +146,16 @@ class CLIPClassificationMixin(LightningFabricMixin):
         self.visual_projection = self.fabric.to_device(self.visual_projection)
         self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)
 
-
-
-
-
-
-
-
-
-
-        cache_dir = os.path.join(
-            self.zeroshot_weights_cache_dir,
-            os.path.normpath(model_name.split("/")[-1]),
-        )
-        if not os.path.exists(cache_dir):
-            log.info(
-                f"Creating cache directory for zero-shot classification head at {cache_dir}"
-            )
-            os.makedirs(cache_dir)
+        @cache_with_joblib()
+        def construct_classification_head(task: str):
+            nonlocal clip_classifier
+
+            classnames, templates = get_classnames_and_templates(task)
+            clip_classifier.set_classification_task(classnames, templates)
+            zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
+
+            return zeroshot_weights
 
-        log.info(f"cache directory for zero-shot classification head: {cache_dir}")
         for task in tqdm(
             self.modelpool.model_names if task_names is None else task_names,
             "Setting up zero-shot classification head",
@@ -158,27 +163,7 @@ class CLIPClassificationMixin(LightningFabricMixin):
         ):
             zeroshot_weights = None
             if self.fabric.is_global_zero:
-                cache_file = os.path.join(
-                    cache_dir, os.path.normpath(f"{task}_zeroshot_weights.pt")
-                )
-                if os.path.exists(cache_file):
-                    zeroshot_weights = torch.load(
-                        cache_file,
-                        map_location="cpu",
-                        weights_only=True,
-                    ).detach()
-                    log.info(
-                        f"Loadded cached zeroshot weights for task: {task}, shape: {zeroshot_weights.shape}"
-                    )
-                else:
-                    log.info(
-                        f"Construct zero shot classification head for task: {task}"
-                    )
-                    classnames, templates = get_classnames_and_templates(task)
-                    clip_classifier.set_classification_task(classnames, templates)
-                    zeroshot_weights = clip_classifier.zeroshot_weights.detach().clone()
-                    log.info(f"save zeroshot weights to {cache_file}")
-                    torch.save(zeroshot_weights, cache_file)
+                zeroshot_weights = construct_classification_head(task)
 
             self.fabric.barrier()
             self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
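The two hunks above replace the hand-rolled cache (an explicit `cache_dir` plus `torch.save`/`torch.load` of per-task `.pt` files) with a `cache_with_joblib`-decorated closure, so cache keys and on-disk storage are handled by joblib. A rough sketch of how such a decorator can be built on `joblib.Memory` follows; the signature and cache location here are assumptions, not the actual fusion_bench implementation.

```python
# Sketch only: the real decorator's signature and cache location may differ.
from joblib import Memory

def cache_with_joblib(cache_dir: str = "outputs/cache"):
    memory = Memory(cache_dir, verbose=0)

    def decorator(func):
        # joblib hashes the arguments and memoizes results on disk,
        # so repeated calls with the same `task` skip the computation.
        return memory.cache(func)

    return decorator

@cache_with_joblib()
def construct_classification_head(task: str):
    return f"expensive-result-for-{task}"  # stand-in for the real computation
```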
@@ -197,16 +182,20 @@ class CLIPClassificationMixin(LightningFabricMixin):
         image_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
-
+        Computes the classification logits for a batch of images for a specific task.
+
+        This method performs zero-shot classification by calculating the cosine similarity between image and text embeddings.
+        The image embeddings are obtained from the provided vision model, and the text embeddings (zero-shot weights) are pre-computed for the task.
+        The similarity scores are then scaled by the CLIP model's `logit_scale` to produce the final logits.
 
         Args:
-            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The
-            images (torch.Tensor):
-            task (str): The
-            image_embeds (Optional[torch.Tensor]):
+            module (Union[nn.Module, CLIPVisionModel, "CLIPVisionTransformer"]): The vision encoder part of the CLIP model.
+            images (torch.Tensor): A batch of images to classify.
+            task (str): The name of the classification task.
+            image_embeds (Optional[torch.Tensor]): Pre-computed image embeddings. If provided, the method skips the image encoding step.
 
         Returns:
-            torch.Tensor:
+            torch.Tensor: A tensor of logits for each image, with shape (batch_size, num_classes).
         """
         text_embeds = self.zeroshot_weights[task]
 
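The rewritten docstring spells out the computation: cosine similarity between normalized image and text embeddings, scaled by `logit_scale`. As a sketch (illustrative names, not the method's exact body):

```python
import torch

def zero_shot_logits(
    image_embeds: torch.Tensor,     # (batch_size, dim), after visual_projection
    text_embeds: torch.Tensor,      # (num_classes, dim), the zero-shot weights
    logit_scale_exp: torch.Tensor,  # exp() of CLIP's learned logit_scale
) -> torch.Tensor:
    # Cosine similarity is the dot product of L2-normalized embeddings.
    image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
    text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
    # Scale the similarities to obtain logits of shape (batch_size, num_classes).
    return logit_scale_exp * image_embeds @ text_embeds.T
```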
fusion_bench/mixins/serialization.py:

@@ -4,7 +4,7 @@ from copy import deepcopy
 from functools import wraps
 from inspect import Parameter, _ParameterKind
 from pathlib import Path
-from typing import Dict, Optional, Union
+from typing import Dict, Mapping, Optional, Union
 
 from omegaconf import DictConfig, OmegaConf
 
@@ -21,6 +21,20 @@ __all__ = [
 ]
 
 
+def _get_attr_name(config_mapping: Mapping[str, str], param_name):
+    for attr_name, p in config_mapping.items():
+        if p == param_name:
+            return attr_name
+    else:
+        raise ValueError(f"Parameter {param_name} not found in config mapping.")
+
+
+def _set_attr(self, param_name: str, value):
+    attr_name = _get_attr_name(self._config_mapping, param_name)
+    log.debug(f"set {attr_name} to {value}. Parameter name: {param_name}")
+    setattr(self, attr_name, value)
+
+
 def auto_register_config(cls):
     """
     Decorator to automatically register __init__ parameters in _config_mapping.
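The new helpers invert `_config_mapping`, which maps attribute names to `__init__` parameter names: `_get_attr_name` finds the attribute that stores a given parameter (note the `for`/`else`: the `ValueError` is raised only when the loop finishes without returning), and `_set_attr` writes through that mapping. A hypothetical example, assuming the two helpers above are in scope:

```python
# Hypothetical class, invented for illustration.
class Example:
    # attribute name -> __init__ parameter name
    _config_mapping = {"_model_path": "model_path", "device": "device"}

obj = Example()
_set_attr(obj, "model_path", "models/llama")  # writes obj._model_path
assert obj._model_path == "models/llama"
assert _get_attr_name(Example._config_mapping, "device") == "device"
_get_attr_name(Example._config_mapping, "unknown")  # raises ValueError
```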
@@ -56,8 +70,8 @@ def auto_register_config(cls):
     ```python
     @auto_register_config
     class MyAlgorithm(BaseYAMLSerializable):
-        def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default"):
-            super().__init__()
+        def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default", **kwargs):
+            super().__init__(**kwargs)
 
     # All instantiation methods work automatically:
     algo1 = MyAlgorithm(0.01, 64)  # positional args
@@ -90,14 +104,20 @@ def auto_register_config(cls):
     # Auto-register parameters in _config_mapping
     if not "_config_mapping" in cls.__dict__:
         cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", {}))
+    registered_parameters = tuple(cls._config_mapping.values())
+
     for param_name in list(sig.parameters.keys())[1:]:  # Skip 'self'
-        if sig.parameters[param_name].kind not in [
-            _ParameterKind.VAR_POSITIONAL,
-            _ParameterKind.VAR_KEYWORD,
-        ]:
+        if (
+            sig.parameters[param_name].kind
+            not in [
+                _ParameterKind.VAR_POSITIONAL,
+                _ParameterKind.VAR_KEYWORD,
+            ]
+        ) and (param_name not in registered_parameters):
             cls._config_mapping[param_name] = param_name
 
     def __init__(self, *args, **kwargs):
+        log.debug(f"set attributes for {self.__class__.__name__} in {cls.__name__}")
         # auto-register the attributes based on the signature
         sig = inspect.signature(original_init)
         param_names = list(sig.parameters.keys())[1:]  # Skip 'self'
@@ -110,7 +130,7 @@ def auto_register_config(cls):
                     _ParameterKind.VAR_POSITIONAL,
                     _ParameterKind.VAR_KEYWORD,
                 ]:
-                    setattr(self, param_name, arg_value)
+                    _set_attr(self, param_name, arg_value)
 
         # Handle keyword arguments and defaults
         for param_name in param_names:
@@ -124,12 +144,12 @@ def auto_register_config(cls):
                 continue
 
             if param_name in kwargs:
-                _set_attr := setattr(self, param_name, kwargs[param_name])
+                _set_attr(self, param_name, kwargs[param_name])
             else:
                 # Set default value if available and attribute doesn't exist
                 default_value = sig.parameters[param_name].default
                 if default_value is not Parameter.empty:
-                    setattr(self, param_name, default_value)
+                    _set_attr(self, param_name, default_value)
 
         # Call the original __init__
         result = original_init(self, *args, **kwargs)
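With the three hunks above, every attribute assignment in the generated `__init__` goes through `_set_attr`, so positional arguments, keyword arguments, and defaults all land on the attribute named in `_config_mapping`. The sketch below mirrors the docstring example (hypothetical class; import paths are assumptions):

```python
from fusion_bench.mixins import BaseYAMLSerializable
from fusion_bench.mixins.serialization import auto_register_config

@auto_register_config
class MyAlgorithm(BaseYAMLSerializable):
    def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, **kwargs):
        super().__init__(**kwargs)

algo = MyAlgorithm(0.01)           # positional argument
assert algo.learning_rate == 0.01  # set through _set_attr before __init__ runs
assert algo.batch_size == 32       # default value, also set through _set_attr
```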
fusion_bench/modelpool/base_pool.py:

@@ -277,7 +277,7 @@ class BaseModelPool(
         for dataset_name in self.test_dataset_names:
             yield self.load_test_dataset(dataset_name)
 
-    def save_model(self, model: nn.Module, path: str):
+    def save_model(self, model: nn.Module, path: str, *args, **kwargs):
         """
         Save the state dictionary of the model to the specified path.
 
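Widening `save_model` to accept `*args`/`**kwargs` lets subclasses take extra saving options while staying compatible with callers that only pass `(model, path)`. A hypothetical subclass (the `save_dtype` option and import path are invented for illustration):

```python
from torch import nn
from fusion_bench.modelpool import BaseModelPool

class MyModelPool(BaseModelPool):
    def save_model(self, model: nn.Module, path: str, *args, save_dtype=None, **kwargs):
        # Consume the extra option here; remaining args flow through to the base class.
        if save_dtype is not None:
            model = model.to(dtype=save_dtype)
        super().save_model(model, path, *args, **kwargs)
```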