fusion-bench 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/clip_classification.py +26 -6
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/serialization.py +40 -83
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +285 -44
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +10 -4
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +12 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +36 -11
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/lazy_state_dict.py +85 -19
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/rich_utils.py +7 -3
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +77 -64
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
@@ -45,21 +45,21 @@ def linearize_lora_model_(model):
 
 
 def load_fft_vision_model_hf(
-    model_name: str,
+    model_name: str, return_vision_model=True
 ) -> Union[CLIPVisionTransformer, CLIPVisionModel]:
     """
     Load a CLIP vision model from Hugging Face.
 
     Args:
         model_name (str): The name of the CLIP vision model to load from Hugging Face.
-
+        return_vision_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.
 
     Returns:
         Union[CLIPVisionTransformer, CLIPVisionModel]: The vision model.
     """
     model = CLIPVisionModel.from_pretrained(model_name)
 
-    if
+    if return_vision_model:
         return CLIPVisionModel.from_pretrained(model_name).vision_model
     else:
         return model
@@ -69,7 +69,7 @@ def load_lora_vision_model_hf(
     base_model_name: str,
     peft_name: str,
     merge_and_unload: bool = False,
-
+    return_vision_model=True,
 ) -> PeftModel:
     """
     Load a LoRA (Low-Rank Adaptation) vision model from Hugging Face.
@@ -80,7 +80,7 @@ def load_lora_vision_model_hf(
         base_model_name (str): The name of the base vision model to load from Hugging Face.
         peft_name (str): The name of the LoRA adaptation to apply to the base model.
         merge_and_unload (bool, optional): If True, the LoRA adaptation is merged into the base model and the LoRA layers are removed. Defaults to False.
-
+        return_vision_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.
 
     Returns:
         PeftModel: The adapted vision model, optionally merged and unloaded.
@@ -97,7 +97,7 @@ def load_lora_vision_model_hf(
     vision_model = peft_model
 
     # Return the vision model
-    if
+    if return_vision_model:
         return vision_model
     else:
         model.vision_model = vision_model
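These hunks appear to come from `fusion_bench/models/linearized/vision_model.py` (the only entry in the file list with a matching +6/−6 change): both loaders gain a `return_vision_model` flag that decides whether the caller receives the full `CLIPVisionModel` or only its inner `CLIPVisionTransformer`. A minimal sketch of the two return shapes using the `transformers` API directly; the checkpoint name is an arbitrary example, not taken from the diff:

```python
from transformers import CLIPVisionModel

# Hypothetical example checkpoint; any CLIP vision checkpoint behaves the same way.
model_name = "openai/clip-vit-base-patch32"

# return_vision_model=False corresponds to the full wrapper model.
full_model = CLIPVisionModel.from_pretrained(model_name)

# return_vision_model=True corresponds to the inner transformer only.
vision_transformer = full_model.vision_model

print(type(full_model).__name__)          # CLIPVisionModel
print(type(vision_transformer).__name__)  # CLIPVisionTransformer
```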
@@ -1,5 +1,8 @@
 ---
 base_model:
+{%- if base_model is not none %}
+- {{ base_model }}
+{%- endif %}
 {%- for model in models %}
 - {{ model }}
 {%- endfor %}
@@ -18,7 +21,11 @@ tags:
 This is a merged model created using [fusion-bench](https://github.com/tanganke/fusion_bench).
 
 The following models were included in the merge:
-
+
+{% if base_model is not none %}
+- base model: {{ base_model }}
+{%- endif %}
+{%- for model in models %}
 - {{ model }}
 {%- endfor %}
 
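The pair of hunks above updates the Jinja2 model-card template (presumably `fusion_bench/models/model_card_templates/default.md`, listed at +8/−1): an optional `base_model` entry is added to the YAML front matter and to the model list. A small sketch of how the front-matter snippet renders, using the template text from the hunk and placeholder model names:

```python
from jinja2 import Template

snippet = """\
base_model:
{%- if base_model is not none %}
- {{ base_model }}
{%- endif %}
{%- for model in models %}
- {{ model }}
{%- endfor %}
"""

template = Template(snippet)

# With a base model, the list starts with it; with base_model=None, it is omitted.
print(template.render(base_model="org/base-model", models=["org/model-a", "org/model-b"]))
print(template.render(base_model=None, models=["org/model-a", "org/model-b"]))
```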
fusion_bench/models/we_moe.py CHANGED
@@ -1,6 +1,6 @@
 import functools
 import logging
-from typing import List
+from typing import Generic, List
 
 import torch
 import torch.func
@@ -9,7 +9,7 @@ from torch.func import functional_call
 from torch.nn import functional as F
 
 from fusion_bench.models.utils import del_attr, get_attr, set_attr
-from fusion_bench.utils.type import StateDictType
+from fusion_bench.utils.type import StateDictType, TorchModelType
 
 log = logging.getLogger(__name__)
 
@@ -76,15 +76,15 @@ def construct_weight_ensembling_gate(
     return gate
 
 
-class WeightEnsemblingMoE(nn.Module):
+class WeightEnsemblingMoE(nn.Module, Generic[TorchModelType]):
     # variable to store the merged state dict temporarily
     _merged_state_dict: StateDictType = None
 
     def __init__(
         self,
         hidden_size: int,
-        base_model:
-        expert_models: List[
+        base_model: TorchModelType,
+        expert_models: List[TorchModelType],
         init_lambda: float = 0.2,
         batch_first: bool = False,
         router_hidden_layers: int = 2,
@@ -101,8 +101,8 @@ class WeightEnsemblingMoE(nn.Module):
         Args:
 
             hidden_size (int): The size of the hidden layer in the models.
-            base_model (
-            expert_models (List[
+            base_model (TorchModelType): The base model that will be used as a reference for the expert models.
+            expert_models (List[TorchModelType]): A list of expert models that will be combined.
            init_lambda (float, optional): The initial lambda value for the weight ensembling gate. Defaults to 0.2.
            batch_first (bool, optional): If True, the input tensors are expected to have the batch size as the first dimension. Defaults to False.
            router_hidden_layers (int, optional): The number of hidden layers in the router. Defaults to 2.
@@ -145,7 +145,7 @@ class WeightEnsemblingMoE(nn.Module):
             self._merged_state_dict,
         )
 
-    def merge_weights(self, expert_weights):
+    def merge_weights(self, expert_weights) -> StateDictType:
         state_dict = self.base_model.state_dict(keep_vars=True)
         for weight, task_vector in zip(expert_weights, self.task_vectors):
             for name, param in task_vector.named_parameters():
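`WeightEnsemblingMoE` is now generic over the wrapped model type via `TorchModelType`, imported from `fusion_bench.utils.type`. The diff does not show that definition; the sketch below assumes `TorchModelType` is a `TypeVar` bound to `nn.Module` and only illustrates the typing pattern, not the actual class:

```python
from typing import Generic, List, TypeVar

from torch import nn

# Assumed definition; in fusion-bench this presumably lives in fusion_bench.utils.type.
TorchModelType = TypeVar("TorchModelType", bound=nn.Module)


class ModelContainer(Generic[TorchModelType]):
    """Toy container showing how Generic[TorchModelType] carries the model type through."""

    def __init__(self, base_model: TorchModelType, expert_models: List[TorchModelType]):
        self.base_model = base_model
        self.expert_models = expert_models


# A type checker infers ModelContainer[nn.Linear] here, so members stay precisely typed.
container = ModelContainer(nn.Linear(4, 4), [nn.Linear(4, 4), nn.Linear(4, 4)])
```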
@@ -1,10 +1,17 @@
-
+import logging
+from typing import Any, Callable, Dict, Generic, List, Union, cast
 
 import numpy as np
 import torch
+import torch.futures
 from omegaconf import ListConfig
 from torch import Tensor, nn
 
+from fusion_bench.utils.devices import to_device
+from fusion_bench.utils.type import TorchModelType
+
+log = logging.getLogger(__name__)
+
 
 def aggregate_tensors(
     outputs: List[Any], aggregate_fn: Callable
@@ -58,12 +65,16 @@ def aggregate_tensors(
     raise ValueError("Unsupported type for outputs")
 
 
-class EnsembleModule(nn.Module):
+class EnsembleModule(nn.Module, Generic[TorchModelType]):
     """
     Ensemble module that averages the outputs of multiple models.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        models: List[TorchModelType],
+        device_map: Dict[int, Union[int, str]] | None = None,
+    ):
         """
         Initializes the EnsembleModule with a list of models.
 
@@ -73,6 +84,16 @@ class EnsembleModule(nn.Module):
         super().__init__()
         # TODO: distribute models to devices
         self.model_list = nn.ModuleList(models)
+        self.device_map = device_map
+        if self.device_map is not None:
+            self._move_models_to_devices()
+
+    def _move_models_to_devices(self):
+        for model_idx, device_id in self.device_map.items():
+            log.info(f"Moving model {model_idx} to device {device_id}")
+            self.model_list[model_idx] = self.model_list[model_idx].to(
+                device_id, non_blocking=True
+            )
 
     def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
         """
@@ -86,6 +107,49 @@ class EnsembleModule(nn.Module):
         """
         return torch.stack(outputs).mean(dim=0)
 
+    def _parallel_forward_with_device_map(self, *args: Any, **kwargs: Any) -> List[Any]:
+        """
+        Performs parallel forward pass using device mapping with futures.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[Any]: List of outputs from all models, all moved to the same device.
+        """
+        futures = []
+
+        device_data_cache = {}
+        for i, model in enumerate(self.model_list):
+            device_id = self.device_map.get(i, "cpu")
+
+            if device_id not in device_data_cache:
+                # Move inputs to the same device as the model
+                device_args = to_device(
+                    args, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_kwargs = to_device(
+                    kwargs, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_data_cache[device_id] = (device_args, device_kwargs)
+            else:
+                device_args, device_kwargs = device_data_cache[device_id]
+
+            # Create a future for asynchronous execution
+            future = torch.jit.fork(model, *device_args, **device_kwargs)
+            futures.append(future)
+
+        # Wait for all futures to complete and collect results
+        outputs = [torch.jit.wait(future) for future in futures]
+
+        # Move all outputs to the same device (use the device of the first model or cpu as fallback)
+        target_device = self.device_map.get(0, "cpu") if self.device_map else "cpu"
+        outputs = [
+            to_device(output, target_device, non_blocking=True) for output in outputs
+        ]
+        return outputs
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Performs a forward pass by averaging the outputs of the models.
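The new `_parallel_forward_with_device_map` dispatches each ensemble member with `torch.jit.fork` and gathers the results with `torch.jit.wait`, caching the inputs per device so they are copied at most once. A standalone sketch of that fork/wait pattern with plain `nn.Linear` modules on the CPU, without `device_map` or the fusion-bench helpers:

```python
import torch
from torch import nn

models = [nn.Linear(8, 4) for _ in range(3)]
x = torch.randn(2, 8)

# Launch each forward pass asynchronously, then block until every result is ready.
futures = [torch.jit.fork(model, x) for model in models]
outputs = [torch.jit.wait(future) for future in futures]

# Averaging the stacked outputs mirrors EnsembleModule._aggregate_tensors.
ensemble_output = torch.stack(outputs).mean(dim=0)
print(ensemble_output.shape)  # torch.Size([2, 4])
```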
@@ -97,20 +161,25 @@ class EnsembleModule(nn.Module):
         Returns:
             Aggregated output from the ensemble of models.
         """
-
+        if self.device_map is None:
+            outputs = [model(*args, **kwargs) for model in self.model_list]
+        else:
+            # Parallel execution with device mapping
+            outputs = self._parallel_forward_with_device_map(*args, **kwargs)
         return aggregate_tensors(outputs, self._aggregate_tensors)
 
 
-class WeightedEnsembleModule(nn.Module):
+class WeightedEnsembleModule(nn.Module, Generic[TorchModelType]):
     """
     Ensemble module that computes a weighted average of the outputs from multiple models.
     """
 
     def __init__(
         self,
-        models: List[
+        models: List[TorchModelType],
         weights: List[float] | Tensor | np.ndarray,
         normalize: bool = True,
+        device_map: Dict[int, Union[int, str]] | None = None,
     ):
         """
         Initializes the WeightedEnsembleModule with models and their corresponding weights.
@@ -119,9 +188,12 @@ class WeightedEnsembleModule(nn.Module):
             models (List[nn.Module]): List of models to ensemble.
             weights (List[float] | Tensor | np.ndarray): Weights for each model.
             normalize (bool, optional): If True, normalizes the weights. Defaults to True.
+            device_map (Dict[int, Union[int, str]] | None, optional): Device mapping for parallel execution. Defaults to None.
         """
         super().__init__()
         self.model_list = nn.ModuleList(models)
+        self.device_map = device_map
+
         if isinstance(weights, (list, tuple, ListConfig)):
             weights = torch.tensor(weights)
         elif isinstance(weights, Tensor):
@@ -139,6 +211,17 @@ class WeightedEnsembleModule(nn.Module):
             weights = weights / weights.sum()
         self.register_buffer("weights", weights)
 
+        if self.device_map is not None:
+            self._move_models_to_devices()
+
+    def _move_models_to_devices(self):
+        """Move models to their assigned devices according to device_map."""
+        for model_idx, device_id in self.device_map.items():
+            log.info(f"Moving model {model_idx} to device {device_id}")
+            self.model_list[model_idx] = self.model_list[model_idx].to(
+                device_id, non_blocking=True
+            )
+
     def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
         """
         Aggregates a list of tensors using the provided weights.
@@ -152,6 +235,48 @@ class WeightedEnsembleModule(nn.Module):
         weights = cast(Tensor, self.weights).view(-1, *([1] * outputs[0].dim()))
         return (torch.stack(outputs) * weights).sum(dim=0)
 
+    def _parallel_forward_with_device_map(self, *args: Any, **kwargs: Any) -> List[Any]:
+        """
+        Performs parallel forward pass using device mapping with futures.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[Any]: List of outputs from all models, all moved to the same device.
+        """
+        futures = []
+
+        device_data_cache = {}
+        for i, model in enumerate(self.model_list):
+            device_id = self.device_map.get(i, "cpu")
+
+            if device_id not in device_data_cache:
+                # Move inputs to the same device as the model
+                device_args = to_device(
+                    args, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_kwargs = to_device(
+                    kwargs, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_data_cache[device_id] = (device_args, device_kwargs)
+            else:
+                device_args, device_kwargs = device_data_cache[device_id]
+
+            # Create a future for asynchronous execution
+            future = torch.jit.fork(model, *device_args, **device_kwargs)
+            futures.append(future)
+
+        # Wait for all futures to complete and collect results
+        outputs = [torch.jit.wait(future) for future in futures]
+
+        # Move all outputs to the same device (use the device of the first model or cpu as fallback)
+        target_device = self.device_map.get(0, "cpu") if self.device_map else "cpu"
+        outputs = [to_device(output, target_device) for output in outputs]
+
+        return outputs
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Performs a forward pass by computing the weighted average of the models' outputs.
@@ -163,7 +288,11 @@ class WeightedEnsembleModule(nn.Module):
         Returns:
             Weighted aggregated output from the ensemble of models.
         """
-
+        if self.device_map is None:
+            outputs = [model(*args, **kwargs) for model in self.model_list]
+        else:
+            # Parallel execution with device mapping
+            outputs = self._parallel_forward_with_device_map(*args, **kwargs)
         return aggregate_tensors(outputs, self._aggregate_tensors)
 
 
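Taken together, these hunks let `EnsembleModule` and `WeightedEnsembleModule` (from `fusion_bench/models/wrappers/ensemble.py` in the file list) place each member on its own device via `device_map` and run the members in parallel. A usage sketch based on the constructors shown above; it keeps everything on the CPU so it runs without GPUs, and the import path is inferred from the file location:

```python
import torch
from torch import nn

from fusion_bench.models.wrappers.ensemble import EnsembleModule, WeightedEnsembleModule

models = [nn.Linear(8, 4) for _ in range(2)]
x = torch.randn(2, 8)

# Without device_map the forward pass stays sequential, as in previous releases.
ensemble = EnsembleModule(models)
print(ensemble(x).shape)  # torch.Size([2, 4])

# With device_map each model is moved to its device and executed via torch.jit.fork;
# on a multi-GPU machine this would typically be {0: "cuda:0", 1: "cuda:1"}.
weighted = WeightedEnsembleModule(
    models, weights=[0.3, 0.7], device_map={0: "cpu", 1: "cpu"}
)
print(weighted(x).shape)  # torch.Size([2, 4])
```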
fusion_bench/scripts/cli.py CHANGED
@@ -20,8 +20,8 @@ log = logging.getLogger(__name__)
 
 
 def _get_default_config_path():
-    for
-    for
+    for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+        for config_dir in ["config", "fusion_bench_config"]:
             config_path = os.path.join(config_path_root, config_dir)
             if os.path.exists(config_path) and os.path.isdir(config_path):
                 return os.path.abspath(config_path)
@@ -5,33 +5,115 @@ from fusion_bench.mixins import BaseYAMLSerializable
 
 
 class BaseTaskPool(BaseYAMLSerializable):
+    """Abstract base class for task pools in the FusionBench framework.
+
+    A task pool represents a collection of evaluation tasks that can be used to
+    assess model performance across multiple benchmarks or datasets. This base
+    class defines the common interface that all task pool implementations must
+    follow, ensuring consistency across different task types and evaluation
+    scenarios.
+
+    Task pools are designed to be configurable through YAML files and can be
+    used in various model fusion and evaluation workflows. They provide a
+    standardized way to evaluate models on multiple tasks and aggregate results.
+
+    The class inherits from BaseYAMLSerializable to support configuration
+    management and serialization capabilities.
+
+    Attributes:
+        _program: Optional program reference for execution context.
+        _config_key: Configuration key used for YAML configuration ("taskpool").
+
+    Abstract Methods:
+        evaluate: Must be implemented by subclasses to define task-specific
+            evaluation logic.
+
+    Example:
+        Implementing a custom task pool:
+
+        ```python
+        class MyTaskPool(BaseTaskPool):
+
+
+            def evaluate(self, model, **kwargs):
+                results = {}
+                for task_name in self.tasks:
+                    # Implement task-specific evaluation
+                    results[task_name] = self._evaluate_task(model, task_name)
+                return results
+        ```
+    """
+
     _program = None
     _config_key = "taskpool"
 
     @abstractmethod
     def evaluate(self, model: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]:
-        """
-        Evaluate the model on all tasks in the task pool, and return a report.
+        """Evaluate a model on all tasks in the task pool and return aggregated results.
 
-
+        This abstract method defines the core evaluation interface that all task pool
+        implementations must provide. It should evaluate the given model on all tasks
+        managed by the pool and return a structured report of the results.
 
-
-
-
-
-
-            },
-            <task_name>: {
-                <metric_name>: <metric_value>,
-                ...
-            },
-        }
-        ```
+        The evaluation process typically involves:
+        1. Iterating through all tasks in the pool
+        2. Running model inference on each task's dataset
+        3. Computing task-specific metrics
+        4. Aggregating results into a standardized report format
 
         Args:
-            model: The model to evaluate.
+            model: The model to evaluate. Can be any model type (PyTorch model,
+                Hugging Face model, etc.) that is compatible with the specific
+                task pool implementation.
+            *args: Additional positional arguments that may be needed for
+                task-specific evaluation procedures.
+            **kwargs: Additional keyword arguments for evaluation configuration,
+                such as batch_size, device, evaluation metrics, etc.
 
         Returns:
-
+            Dict[str, Any]: A dictionary containing evaluation results for each task.
+            The structure follows the pattern:
+
+            ```python
+            {
+                "task_name_1": {
+                    "metric_1": value,
+                    "metric_2": value,
+                    ...
+                },
+                "task_name_2": {
+                    "metric_1": value,
+                    "metric_2": value,
+                    ...
+                },
+                ...
+            }
+            ```
+
+        Example:
+            For an image classification task pool:
+
+            ```python
+            results = task_pool.evaluate(model)
+            # Returns:
+            # {
+            #     "mnist": {
+            #         "accuracy": 0.95,
+            #         "loss": 0.15,
+            #     },
+            #     "cifar10": {
+            #         "accuracy": 0.87,
+            #         "loss": 0.42,
+            #     }
+            # }
+            ```
+
+        Raises:
+            NotImplementedError: This method must be implemented by subclasses.
+
+        Note:
+            Implementations should ensure that the returned dictionary structure
+            is consistent and that metric names are standardized across similar
+            task types to enable meaningful comparison and aggregation.
         """
         pass
@@ -27,7 +27,7 @@ from tqdm.autonotebook import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPVisionTransformer
 
-from fusion_bench import RuntimeConstants
+from fusion_bench import RuntimeConstants, auto_register_config
 from fusion_bench.dataset import CLIPDataset
 from fusion_bench.mixins import HydraConfigMixin, LightningFabricMixin
 from fusion_bench.models.hf_clip import HFCLIPClassifier
@@ -86,6 +86,7 @@ class LayerWiseFeatureSaver:
         torch.save(features, self.save_path)
 
 
+@auto_register_config
 class CLIPVisionModelTaskPool(
     HydraConfigMixin,
     LightningFabricMixin,
@@ -134,11 +135,13 @@ class CLIPVisionModelTaskPool(
         layer_wise_feature_first_token_only: bool = True,
         layer_wise_feature_max_num: Optional[int] = None,
         fast_dev_run: Optional[bool] = None,
+        move_to_device: bool = True,
         **kwargs,
     ):
         """
         Initialize the CLIPVisionModelTaskPool.
         """
+        super().__init__(**kwargs)
         self._test_datasets = test_datasets
         self._processor = processor
         self._data_processor = data_processor
@@ -159,7 +162,6 @@ class CLIPVisionModelTaskPool(
             self.fast_dev_run = RuntimeConstants().debug
         else:
             self.fast_dev_run = fast_dev_run
-        super().__init__(**kwargs)
 
     def setup(self):
         """
@@ -220,7 +222,9 @@ class CLIPVisionModelTaskPool(
             for name, dataset in self.test_datasets.items()
         }
         self.test_dataloaders = {
-            name: self.fabric.setup_dataloaders(
+            name: self.fabric.setup_dataloaders(
+                dataloader, move_to_device=self.move_to_device
+            )
             for name, dataloader in self.test_dataloaders.items()
         }
 
@@ -273,6 +277,8 @@ class CLIPVisionModelTaskPool(
                 task_name=task_name,
             )
             logits: Tensor = outputs["logits"]
+            if logits.device != targets.device:
+                targets = targets.to(logits.device)
 
             loss = F.cross_entropy(logits, targets)
             loss_metric.update(loss.detach().cpu())
@@ -309,7 +315,7 @@ class CLIPVisionModelTaskPool(
         self.setup()
 
         report = {}
-        # CLIPVisionModel works the same with
+        # CLIPVisionModel works the same with CLIPVisionTransformer, so we can use it directly
         if hasattr(model, "is_surgery_model") and model.is_surgery_model:
             log.info("running evaluation on a surgery model.")
             model: "SurgeryModelWrapper" = model
@@ -321,7 +327,8 @@ class CLIPVisionModelTaskPool(
             self.clip_model,
             processor=self.processor,
         )
-
+        if self.move_to_device:
+            classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
         # collect basic model information
         training_params, all_params = count_parameters(model)
         report["model_info"] = {