fusion-bench 0.2.23__py3-none-any.whl → 0.2.24__py3-none-any.whl
- fusion_bench/method/__init__.py +8 -0
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/serialization.py +25 -78
- fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +2 -1
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
- fusion_bench/utils/devices.py +30 -8
- fusion_bench/utils/lazy_state_dict.py +3 -0
- fusion_bench/utils/rich_utils.py +7 -3
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +37 -30
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
fusion_bench/mixins/serialization.py
CHANGED

```diff
@@ -6,6 +6,7 @@ from inspect import Parameter, _ParameterKind
 from pathlib import Path
 from typing import Dict, Mapping, Optional, Union

+from bidict import MutableBidict, bidict
 from omegaconf import DictConfig, OmegaConf

 from fusion_bench.constants import FUSION_BENCH_VERSION
@@ -15,22 +16,29 @@ from fusion_bench.utils.instantiate_utils import set_print_function_call
 log = logging.getLogger(__name__)

 __all__ = [
-    "YAMLSerializationMixin",
     "auto_register_config",
+    "YAMLSerializationMixin",
     "BaseYAMLSerializable",
 ]


-def
-
-
-
-
-
+def _set_attr(self, param_name: str, value):
+    """
+    Set an attribute on the object using the parameter name from config mapping.
+
+    This function looks up the corresponding attribute name for the given parameter
+    name using the object's _config_mapping, then sets that attribute to the
+    specified value. It also logs the operation for debugging purposes.

+    Args:
+        self: The object instance to set the attribute on.
+        param_name (str): The parameter name (config key) to map to an attribute.
+        value: The value to assign to the attribute.

-
-
+    Raises:
+        ValueError: If the parameter name is not found in the config mapping.
+    """
+    attr_name = self._config_mapping.inverse[param_name]
     log.debug(f"set {attr_name} to {value}. Parameter name: {param_name}")
     setattr(self, attr_name, value)

@@ -59,37 +67,16 @@ def auto_register_config(cls):
     functionality and modified __init__ behavior.

     Behavior:
-        - **Parameter Registration**: All non-variadic parameters (excluding
+        - **Parameter Registration**: All non-variadic parameters (excluding ``*args``, ``**kwargs``)
          from the __init__ method are automatically added to _config_mapping
        - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
        - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
        - **Default Values**: Applied when parameters are not provided via arguments
        - **Attribute Setting**: All parameters become instance attributes accessible via dot notation

-    Example:
-        ```python
-        @auto_register_config
-        class MyAlgorithm(BaseYAMLSerializable):
-            def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default", **kwargs):
-                super().__init__(**kwargs)
-
-        # All instantiation methods work automatically:
-        algo1 = MyAlgorithm(0.01, 64)  # positional args
-        algo2 = MyAlgorithm(learning_rate=0.01, model_name="bert")  # keyword args
-        algo3 = MyAlgorithm(0.01, batch_size=128, model_name="gpt")  # mixed args
-
-        # Attributes are automatically set and can be serialized:
-        print(algo1.learning_rate)  # 0.01
-        print(algo1.batch_size)  # 64
-        print(algo1.model_name)  # "default" (from default value)
-
-        config = algo1.config
-        # DictConfig({'_target_': 'MyAlgorithm', 'learning_rate': 0.01, 'batch_size': 64, 'model_name': 'default'})
-        ```
-
     Note:
        - The decorator wraps the original __init__ method while preserving its signature for IDE support
-       - Parameters with
+       - Parameters with ``*args`` or ``**kwargs`` signatures are ignored during registration
        - The attributes are auto-registered, then the original __init__ method is called,
        - Type hints, method name, and other metadata are preserved using functools.wraps
        - This decorator is designed to work seamlessly with the YAML serialization system
@@ -103,7 +90,10 @@ def auto_register_config(cls):

    # Auto-register parameters in _config_mapping
    if not "_config_mapping" in cls.__dict__:
-        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping",
+        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", bidict()))
+    if not isinstance(cls._config_mapping, bidict):
+        cls._config_mapping = bidict(cls._config_mapping)
+
    registered_parameters = tuple(cls._config_mapping.values())

    for param_name in list(sig.parameters.keys())[1:]:  # Skip 'self'
@@ -116,6 +106,7 @@ def auto_register_config(cls):
    ) and (param_name not in registered_parameters):
        cls._config_mapping[param_name] = param_name

+    @wraps(original_init)
    def __init__(self, *args, **kwargs):
        log.debug(f"set attributes for {self.__class__.__name__} in {cls.__name__}")
        # auto-register the attributes based on the signature
@@ -162,33 +153,10 @@ def auto_register_config(cls):

 class YAMLSerializationMixin:
     _config_key: Optional[str] = None
-    _config_mapping:
+    _config_mapping: MutableBidict[str, str] = bidict()
     R"""
    `_config_mapping` is a dictionary mapping the attribute names of the class to the config option names. This is used to convert the class to a DictConfig.

-    For example, if an algorithm class is defined as follows:
-
-    ```python
-    class SomeModelFusionAlgorithm(BaseModelFusionAlgorithm):
-        hyper_parameter_1 = None
-        hyper_parameter_2 = None
-
-        _config_mapping = BaseModelFusionAlgorithm._config_mapping | {
-            "hyper_parameter_1" : "hyper_param_1",
-            "hyper_parameter_2" : "hyper_param_2",
-        }
-        def __init__(self, hyper_param_1: int, hyper_param_2: int):
-            self.hyper_parameter_1 = hyper_param_1
-            self.hyper_parameter_2 = hyper_param_2
-            super().__init__()
-    ```
-
-    The model pool will be converted to a DictConfig as follows:
-
-    ```python
-    algorithm = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-    ```
-
    >>> algorithm.config
    DictCOnfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})

@@ -207,17 +175,6 @@ class YAMLSerializationMixin:
        This property converts the model pool instance into a dictionary
        configuration, which can be used for serialization or other purposes.

-        Example:
-
-        ```python
-        model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-        config = model.config
-        print(config)
-        # DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
-        ```
-
-        This is useful for serializing the object to a YAML file or for debugging.
-
        Returns:
            DictConfig: The configuration of the model pool.
        """
@@ -282,16 +239,6 @@ class YAMLSerializationMixin:
            serialization. This is how the attribute will appear in YAML output.
            value: The value to assign to the attribute.

-        Example:
-            ```python
-            model = BaseYAMLSerializable()
-            model.set_option("learning_rate", "lr", 0.001)
-
-            # This sets model.learning_rate = 0.001
-            # and maps it to "lr" in the config output
-            config = model.config
-            # config will contain: {"lr": 0.001, ...}
-            ```
        """
        setattr(self, attr_name, value)
        self._config_mapping[attr_name] = param_name
```
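The serialization changes above swap the plain-dict `_config_mapping` for a `bidict`, which is what lets `_set_attr` resolve a config key back to its attribute name via `.inverse`. A minimal sketch of the decorator in action, adapted from the docstring example this release trims (`MyAlgorithm` is illustrative, not part of the package):

```python
from bidict import bidict

from fusion_bench.mixins.serialization import BaseYAMLSerializable, auto_register_config


@auto_register_config
class MyAlgorithm(BaseYAMLSerializable):
    def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, **kwargs):
        super().__init__(**kwargs)


# Positional, keyword, and default arguments are all registered as attributes.
algo = MyAlgorithm(0.01, batch_size=64)
print(algo.learning_rate, algo.batch_size)  # 0.01 64

# _config_mapping is now a bidict, so lookups work in both directions;
# _set_attr uses the inverse view (config key -> attribute name).
assert isinstance(MyAlgorithm._config_mapping, bidict)
print(MyAlgorithm._config_mapping.inverse["learning_rate"])  # learning_rate
```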
fusion_bench/modelpool/causal_lm/causal_lm.py
CHANGED

```diff
@@ -1,5 +1,5 @@
 """
-Online documentation for this module: https://tanganke.github.io/fusion_bench/modelpool/
+Online documentation for this module: https://tanganke.github.io/fusion_bench/modelpool/llm
 """

 import logging
@@ -26,6 +26,7 @@ from fusion_bench import (
    instantiate,
    parse_dtype,
 )
+from fusion_bench.models.hf_utils import create_default_model_card
 from fusion_bench.utils.lazy_state_dict import LazyStateDict

 log = logging.getLogger(__name__)
@@ -271,13 +272,16 @@ class CausalLMPool(BaseModelPool):
        save_tokenizer: bool = False,
        tokenizer_kwargs=None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
+        algorithm_config: Optional[DictConfig] = None,
+        description: Optional[str] = None,
+        base_model_in_modelcard: bool = True,
        **kwargs,
    ):
        """Save a model to the specified path with optional tokenizer and Hub upload.

        This method provides comprehensive model saving capabilities including
-        optional tokenizer saving, dtype conversion,
-        The model is saved in the standard Hugging Face format.
+        optional tokenizer saving, dtype conversion, model card creation, and
+        Hugging Face Hub upload. The model is saved in the standard Hugging Face format.

        Args:
            model: The PreTrainedModel instance to be saved.
@@ -295,15 +299,13 @@ class CausalLMPool(BaseModelPool):
                when save_tokenizer is True.
            tokenizer: Optional pre-loaded tokenizer instance. If provided, this
                tokenizer will be saved regardless of the save_tokenizer flag.
+            algorithm_config: Optional DictConfig containing algorithm configuration.
+                If provided, a model card will be created with algorithm details.
+            description: Optional description for the model card. If not provided
+                and algorithm_config is given, a default description will be generated.
            **kwargs: Additional keyword arguments passed to the model's
                save_pretrained method.

-        Side Effects:
-            - Creates model files in the specified directory
-            - Optionally creates tokenizer files in the same directory
-            - May convert the model to a different dtype
-            - May upload files to Hugging Face Hub
-
        Example:
            ```python
            >>> pool = CausalLMPool(models=..., tokenizer=...)
@@ -313,7 +315,9 @@ class CausalLMPool(BaseModelPool):
            ...     "/path/to/save",
            ...     save_tokenizer=True,
            ...     model_dtype="float16",
-            ...     push_to_hub=True
+            ...     push_to_hub=True,
+            ...     algorithm_config=algorithm_config,
+            ...     description="Custom merged model"
            ... )
            ```
        """
@@ -337,6 +341,24 @@ class CausalLMPool(BaseModelPool):
            **kwargs,
        )

+        # Create and save model card if algorithm_config is provided
+        if algorithm_config is not None:
+            if description is None:
+                description = "Model created using FusionBench."
+            model_card_str = create_default_model_card(
+                base_model=(
+                    self.get_model_path("_pretrained_")
+                    if base_model_in_modelcard and self.has_pretrained
+                    else None
+                ),
+                models=[self.get_model_path(m) for m in self.model_names],
+                description=description,
+                algorithm_config=algorithm_config,
+                modelpool_config=self.config,
+            )
+            with open(os.path.join(path, "README.md"), "w") as f:
+                f.write(model_card_str)
+

 class CausalLMBackbonePool(CausalLMPool):
    """A specialized model pool that loads only the transformer backbone layers.
```
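With these additions, `save_model` writes a `README.md` model card next to the weights whenever an `algorithm_config` is supplied. A sketch mirroring the updated docstring example; the model configs and the `merged_model`/`algorithm` objects are placeholders, elided just as in the original:

```python
>>> pool = CausalLMPool(models=..., tokenizer=...)
>>> pool.save_model(
...     merged_model,
...     "/path/to/save",
...     save_tokenizer=True,
...     algorithm_config=algorithm.config,
...     description="Merged model for math and code tasks",
... )
# Saves the weights and tokenizer, then generates README.md via
# create_default_model_card() using the pool's model paths.
```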
fusion_bench/models/hf_clip.py
CHANGED

```diff
@@ -195,5 +195,9 @@ class HFCLIPClassifier(nn.Module):
            pass
        elif isinstance(image_embeds, BaseModelOutputWithPooling):
            image_embeds = image_embeds[1]
+        elif isinstance(image_embeds, dict) and "pooler_output" in image_embeds:
+            image_embeds = image_embeds["pooler_output"]
+        else:
+            raise ValueError("Unsupported output type from vision model outputs")
        image_embeds = self.clip_model.visual_projection(image_embeds)
        return image_embeds
```
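The vision tower's pooled embedding can now arrive in three shapes: a raw tensor, a `BaseModelOutputWithPooling`, or a plain dict carrying `"pooler_output"`; anything else fails loudly instead of being silently misprojected. A standalone sketch of that dispatch (not the package's code, just the same logic in isolation):

```python
import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling


def pooled(image_embeds):
    # Reduce any supported vision-model output to the pooled embedding.
    if isinstance(image_embeds, torch.Tensor):
        return image_embeds
    elif isinstance(image_embeds, BaseModelOutputWithPooling):
        return image_embeds[1]  # index 1 is pooler_output
    elif isinstance(image_embeds, dict) and "pooler_output" in image_embeds:
        return image_embeds["pooler_output"]
    raise ValueError("Unsupported output type from vision model outputs")


x = torch.randn(2, 768)
assert pooled({"pooler_output": x}) is x
out = BaseModelOutputWithPooling(last_hidden_state=torch.zeros(2, 5, 768), pooler_output=x)
assert pooled(out) is x
```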
fusion_bench/models/hf_utils.py
CHANGED

```diff
@@ -143,7 +143,7 @@ def save_pretrained_with_remote_code(

 def create_default_model_card(
    models: list[str],
-
+    base_model: Optional[str] = None,
    title: str = "Deep Model Fusion",
    tags: list[str] = ["fusion-bench", "merge"],
    description=None,
@@ -154,6 +154,7 @@ def create_default_model_card(

    template: Template = Template(load_model_card_template("default.md"))
    card = template.render(
+        base_model=base_model,
        models=models,
        library_name="transformers",
        title=title,
```
fusion_bench/models/model_card_templates/default.md
CHANGED

```diff
@@ -1,5 +1,8 @@
 ---
 base_model:
+{%- if base_model is not none %}
+- {{ base_model }}
+{%- endif %}
 {%- for model in models %}
 - {{ model }}
 {%- endfor %}
@@ -18,7 +21,11 @@ tags:
 This is a merged model created using [fusion-bench](https://github.com/tanganke/fusion_bench).

 The following models were included in the merge:
-
+
+{% if base_model is not none %}
+- base model: {{ base_model }}
+{%- endif %}
+{%- for model in models %}
 - {{ model }}
 {%- endfor %}

```
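The template now lists an optional base model in both the YAML front matter and the body, guarded by `{%- if base_model is not none %}`. A runnable sketch of how that guard renders, using an inline copy of the front-matter block rather than the packaged `default.md`:

```python
from jinja2 import Template

front_matter = Template(
    "---\n"
    "base_model:\n"
    "{%- if base_model is not none %}\n"
    "- {{ base_model }}\n"
    "{%- endif %}\n"
    "{%- for model in models %}\n"
    "- {{ model }}\n"
    "{%- endfor %}"
)

print(front_matter.render(base_model="org/base", models=["org/ft-a", "org/ft-b"]))
# ---
# base_model:
# - org/base
# - org/ft-a
# - org/ft-b

print(front_matter.render(base_model=None, models=["org/ft-a"]))
# ---
# base_model:
# - org/ft-a
```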
fusion_bench/models/wrappers/ensemble.py
CHANGED

```diff
@@ -1,10 +1,17 @@
-
+import logging
+from typing import Any, Callable, Dict, Generic, List, Union, cast

 import numpy as np
 import torch
+import torch.futures
 from omegaconf import ListConfig
 from torch import Tensor, nn

+from fusion_bench.utils.devices import to_device
+from fusion_bench.utils.type import TorchModelType
+
+log = logging.getLogger(__name__)
+

 def aggregate_tensors(
    outputs: List[Any], aggregate_fn: Callable
@@ -58,12 +65,16 @@ def aggregate_tensors(
    raise ValueError("Unsupported type for outputs")


-class EnsembleModule(nn.Module):
+class EnsembleModule(nn.Module, Generic[TorchModelType]):
    """
    Ensemble module that averages the outputs of multiple models.
    """

-    def __init__(
+    def __init__(
+        self,
+        models: List[TorchModelType],
+        device_map: Dict[int, Union[int, str]] | None = None,
+    ):
        """
        Initializes the EnsembleModule with a list of models.

@@ -73,6 +84,16 @@ class EnsembleModule(nn.Module):
        super().__init__()
        # TODO: distribute models to devices
        self.model_list = nn.ModuleList(models)
+        self.device_map = device_map
+        if self.device_map is not None:
+            self._move_models_to_devices()
+
+    def _move_models_to_devices(self):
+        for model_idx, device_id in self.device_map.items():
+            log.info(f"Moving model {model_idx} to device {device_id}")
+            self.model_list[model_idx] = self.model_list[model_idx].to(
+                device_id, non_blocking=True
+            )

    def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
        """
@@ -86,6 +107,49 @@ class EnsembleModule(nn.Module):
        """
        return torch.stack(outputs).mean(dim=0)

+    def _parallel_forward_with_device_map(self, *args: Any, **kwargs: Any) -> List[Any]:
+        """
+        Performs parallel forward pass using device mapping with futures.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[Any]: List of outputs from all models, all moved to the same device.
+        """
+        futures = []
+
+        device_data_cache = {}
+        for i, model in enumerate(self.model_list):
+            device_id = self.device_map.get(i, "cpu")
+
+            if device_id not in device_data_cache:
+                # Move inputs to the same device as the model
+                device_args = to_device(
+                    args, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_kwargs = to_device(
+                    kwargs, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_data_cache[device_id] = (device_args, device_kwargs)
+            else:
+                device_args, device_kwargs = device_data_cache[device_id]
+
+            # Create a future for asynchronous execution
+            future = torch.jit.fork(model, *device_args, **device_kwargs)
+            futures.append(future)
+
+        # Wait for all futures to complete and collect results
+        outputs = [torch.jit.wait(future) for future in futures]
+
+        # Move all outputs to the same device (use the device of the first model or cpu as fallback)
+        target_device = self.device_map.get(0, "cpu") if self.device_map else "cpu"
+        outputs = [
+            to_device(output, target_device, non_blocking=True) for output in outputs
+        ]
+        return outputs
+
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """
        Performs a forward pass by averaging the outputs of the models.
@@ -97,20 +161,25 @@ class EnsembleModule(nn.Module):
        Returns:
            Aggregated output from the ensemble of models.
        """
-
+        if self.device_map is None:
+            outputs = [model(*args, **kwargs) for model in self.model_list]
+        else:
+            # Parallel execution with device mapping
+            outputs = self._parallel_forward_with_device_map(*args, **kwargs)
        return aggregate_tensors(outputs, self._aggregate_tensors)


-class WeightedEnsembleModule(nn.Module):
+class WeightedEnsembleModule(nn.Module, Generic[TorchModelType]):
    """
    Ensemble module that computes a weighted average of the outputs from multiple models.
    """

    def __init__(
        self,
-        models: List[
+        models: List[TorchModelType],
        weights: List[float] | Tensor | np.ndarray,
        normalize: bool = True,
+        device_map: Dict[int, Union[int, str]] | None = None,
    ):
        """
        Initializes the WeightedEnsembleModule with models and their corresponding weights.
@@ -119,9 +188,12 @@ class WeightedEnsembleModule(nn.Module):
            models (List[nn.Module]): List of models to ensemble.
            weights (List[float] | Tensor | np.ndarray): Weights for each model.
            normalize (bool, optional): If True, normalizes the weights. Defaults to True.
+            device_map (Dict[int, Union[int, str]] | None, optional): Device mapping for parallel execution. Defaults to None.
        """
        super().__init__()
        self.model_list = nn.ModuleList(models)
+        self.device_map = device_map
+
        if isinstance(weights, (list, tuple, ListConfig)):
            weights = torch.tensor(weights)
        elif isinstance(weights, Tensor):
@@ -139,6 +211,17 @@ class WeightedEnsembleModule(nn.Module):
            weights = weights / weights.sum()
        self.register_buffer("weights", weights)

+        if self.device_map is not None:
+            self._move_models_to_devices()
+
+    def _move_models_to_devices(self):
+        """Move models to their assigned devices according to device_map."""
+        for model_idx, device_id in self.device_map.items():
+            log.info(f"Moving model {model_idx} to device {device_id}")
+            self.model_list[model_idx] = self.model_list[model_idx].to(
+                device_id, non_blocking=True
+            )
+
    def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
        """
        Aggregates a list of tensors using the provided weights.
@@ -152,6 +235,48 @@ class WeightedEnsembleModule(nn.Module):
        weights = cast(Tensor, self.weights).view(-1, *([1] * outputs[0].dim()))
        return (torch.stack(outputs) * weights).sum(dim=0)

+    def _parallel_forward_with_device_map(self, *args: Any, **kwargs: Any) -> List[Any]:
+        """
+        Performs parallel forward pass using device mapping with futures.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[Any]: List of outputs from all models, all moved to the same device.
+        """
+        futures = []
+
+        device_data_cache = {}
+        for i, model in enumerate(self.model_list):
+            device_id = self.device_map.get(i, "cpu")
+
+            if device_id not in device_data_cache:
+                # Move inputs to the same device as the model
+                device_args = to_device(
+                    args, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_kwargs = to_device(
+                    kwargs, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_data_cache[device_id] = (device_args, device_kwargs)
+            else:
+                device_args, device_kwargs = device_data_cache[device_id]
+
+            # Create a future for asynchronous execution
+            future = torch.jit.fork(model, *device_args, **device_kwargs)
+            futures.append(future)
+
+        # Wait for all futures to complete and collect results
+        outputs = [torch.jit.wait(future) for future in futures]
+
+        # Move all outputs to the same device (use the device of the first model or cpu as fallback)
+        target_device = self.device_map.get(0, "cpu") if self.device_map else "cpu"
+        outputs = [to_device(output, target_device) for output in outputs]
+
+        return outputs
+
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """
        Performs a forward pass by computing the weighted average of the models' outputs.
@@ -163,7 +288,11 @@ class WeightedEnsembleModule(nn.Module):
        Returns:
            Weighted aggregated output from the ensemble of models.
        """
-
+        if self.device_map is None:
+            outputs = [model(*args, **kwargs) for model in self.model_list]
+        else:
+            # Parallel execution with device mapping
+            outputs = self._parallel_forward_with_device_map(*args, **kwargs)
        return aggregate_tensors(outputs, self._aggregate_tensors)

```
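Both ensemble wrappers now accept an optional `device_map` that assigns each submodel to a device and forks the per-model forward passes with `torch.jit.fork`. A minimal runnable sketch, using CPU-only placements so it works anywhere; with GPUs you would pass e.g. `{0: "cuda:0", 1: "cuda:1"}`:

```python
import torch
from torch import nn

from fusion_bench.models.wrappers.ensemble import EnsembleModule, WeightedEnsembleModule

# Toy submodels standing in for real fine-tuned networks.
models = [nn.Linear(4, 2) for _ in range(3)]

# Without device_map, the forward pass runs sequentially and averages outputs.
ensemble = EnsembleModule(models)
print(ensemble(torch.randn(8, 4)).shape)  # torch.Size([8, 2])

# With device_map, each submodel is moved to its device, inputs are copied
# once per device, and outputs are gathered back onto model 0's device.
weighted = WeightedEnsembleModule(
    [nn.Linear(4, 2) for _ in range(3)],
    weights=[0.5, 0.3, 0.2],
    device_map={0: "cpu", 1: "cpu", 2: "cpu"},
)
print(weighted(torch.randn(8, 4)).shape)  # torch.Size([8, 2])
```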
fusion_bench/scripts/cli.py
CHANGED

```diff
@@ -20,8 +20,8 @@ log = logging.getLogger(__name__)


 def _get_default_config_path():
-    for
-    for
+    for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+        for config_dir in ["config", "fusion_bench_config"]:
            config_path = os.path.join(config_path_root, config_dir)
            if os.path.exists(config_path) and os.path.isdir(config_path):
                return os.path.abspath(config_path)
```
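`_get_default_config_path` now searches the current working directory before the package root, trying `config` and then `fusion_bench_config` under each. A standalone sketch of the same lookup order (the function name and the example root path are illustrative, not the CLI's API):

```python
import os


def find_config_dir(roots, names=("config", "fusion_bench_config")):
    # First existing directory wins: roots are tried in order, then names.
    for root in roots:
        for name in names:
            candidate = os.path.join(root, name)
            if os.path.isdir(candidate):
                return os.path.abspath(candidate)
    return None


print(find_config_dir([os.getcwd(), "/path/to/fusion_bench"]))
```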
fusion_bench/taskpool/clip_vision/taskpool.py
CHANGED

```diff
@@ -27,7 +27,7 @@ from tqdm.autonotebook import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPVisionTransformer

-from fusion_bench import RuntimeConstants
+from fusion_bench import RuntimeConstants, auto_register_config
 from fusion_bench.dataset import CLIPDataset
 from fusion_bench.mixins import HydraConfigMixin, LightningFabricMixin
 from fusion_bench.models.hf_clip import HFCLIPClassifier
@@ -86,6 +86,7 @@ class LayerWiseFeatureSaver:
        torch.save(features, self.save_path)


+@auto_register_config
 class CLIPVisionModelTaskPool(
    HydraConfigMixin,
    LightningFabricMixin,
@@ -134,11 +135,13 @@ class CLIPVisionModelTaskPool(
        layer_wise_feature_first_token_only: bool = True,
        layer_wise_feature_max_num: Optional[int] = None,
        fast_dev_run: Optional[bool] = None,
+        move_to_device: bool = True,
        **kwargs,
    ):
        """
        Initialize the CLIPVisionModelTaskPool.
        """
+        super().__init__(**kwargs)
        self._test_datasets = test_datasets
        self._processor = processor
        self._data_processor = data_processor
@@ -159,7 +162,6 @@ class CLIPVisionModelTaskPool(
            self.fast_dev_run = RuntimeConstants().debug
        else:
            self.fast_dev_run = fast_dev_run
-        super().__init__(**kwargs)

    def setup(self):
        """
@@ -220,7 +222,9 @@ class CLIPVisionModelTaskPool(
            for name, dataset in self.test_datasets.items()
        }
        self.test_dataloaders = {
-            name: self.fabric.setup_dataloaders(
+            name: self.fabric.setup_dataloaders(
+                dataloader, move_to_device=self.move_to_device
+            )
            for name, dataloader in self.test_dataloaders.items()
        }

@@ -273,6 +277,8 @@ class CLIPVisionModelTaskPool(
                task_name=task_name,
            )
            logits: Tensor = outputs["logits"]
+            if logits.device != targets.device:
+                targets = targets.to(logits.device)

            loss = F.cross_entropy(logits, targets)
            loss_metric.update(loss.detach().cpu())
@@ -321,7 +327,8 @@ class CLIPVisionModelTaskPool(
            self.clip_model,
            processor=self.processor,
        )
-
+        if self.move_to_device:
+            classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
        # collect basic model information
        training_params, all_params = count_parameters(model)
        report["model_info"] = {
```