fusion-bench 0.2.23__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. fusion_bench/method/__init__.py +8 -0
  2. fusion_bench/method/ensemble.py +17 -2
  3. fusion_bench/method/linear/__init__.py +6 -2
  4. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  5. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  6. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  7. fusion_bench/method/simple_average.py +2 -2
  8. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  9. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  10. fusion_bench/method/wudi/__init__.py +1 -0
  11. fusion_bench/method/wudi/wudi.py +105 -0
  12. fusion_bench/mixins/lightning_fabric.py +4 -0
  13. fusion_bench/mixins/serialization.py +25 -78
  14. fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
  15. fusion_bench/models/hf_clip.py +4 -0
  16. fusion_bench/models/hf_utils.py +2 -1
  17. fusion_bench/models/model_card_templates/default.md +8 -1
  18. fusion_bench/models/wrappers/ensemble.py +136 -7
  19. fusion_bench/scripts/cli.py +2 -2
  20. fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
  21. fusion_bench/utils/devices.py +30 -8
  22. fusion_bench/utils/lazy_state_dict.py +3 -0
  23. fusion_bench/utils/rich_utils.py +7 -3
  24. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
  25. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +37 -30
  26. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  27. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  28. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  29. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  30. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  31. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  32. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  33. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  34. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  35. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
  36. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
  37. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
  38. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
fusion_bench/mixins/serialization.py

@@ -6,6 +6,7 @@ from inspect import Parameter, _ParameterKind
 from pathlib import Path
 from typing import Dict, Mapping, Optional, Union
 
+from bidict import MutableBidict, bidict
 from omegaconf import DictConfig, OmegaConf
 
 from fusion_bench.constants import FUSION_BENCH_VERSION
@@ -15,22 +16,29 @@ from fusion_bench.utils.instantiate_utils import set_print_function_call
 log = logging.getLogger(__name__)
 
 __all__ = [
-    "YAMLSerializationMixin",
     "auto_register_config",
+    "YAMLSerializationMixin",
    "BaseYAMLSerializable",
 ]
 
 
-def _get_attr_name(config_mapping: Mapping[str, str], param_name):
-    for attr_name, p in config_mapping.items():
-        if p == param_name:
-            return attr_name
-    else:
-        raise ValueError(f"Parameter {param_name} not found in config mapping.")
+def _set_attr(self, param_name: str, value):
+    """
+    Set an attribute on the object using the parameter name from config mapping.
+
+    This function looks up the corresponding attribute name for the given parameter
+    name using the object's _config_mapping, then sets that attribute to the
+    specified value. It also logs the operation for debugging purposes.
 
+    Args:
+        self: The object instance to set the attribute on.
+        param_name (str): The parameter name (config key) to map to an attribute.
+        value: The value to assign to the attribute.
 
-def _set_attr(self, param_name: str, value):
-    attr_name = _get_attr_name(self._config_mapping, param_name)
+    Raises:
+        ValueError: If the parameter name is not found in the config mapping.
+    """
+    attr_name = self._config_mapping.inverse[param_name]
     log.debug(f"set {attr_name} to {value}. Parameter name: {param_name}")
     setattr(self, attr_name, value)
 
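The linear scan in the removed `_get_attr_name` becomes a constant-time lookup on bidict's inverse view. A minimal sketch of the lookup, with illustrative names that are not part of the diff (note that bidict raises `KeyError`, not `ValueError`, on a missing key):

```python
from bidict import bidict

# attribute name -> config option name, kept one-to-one in both directions
config_mapping = bidict({"hyper_parameter_1": "hyper_param_1"})

assert config_mapping["hyper_parameter_1"] == "hyper_param_1"          # attr -> param
assert config_mapping.inverse["hyper_param_1"] == "hyper_parameter_1"  # param -> attr
```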
@@ -59,37 +67,16 @@ def auto_register_config(cls):
         functionality and modified __init__ behavior.
 
     Behavior:
-        - **Parameter Registration**: All non-variadic parameters (excluding *args, **kwargs)
+        - **Parameter Registration**: All non-variadic parameters (excluding ``*args``, ``**kwargs``)
           from the __init__ method are automatically added to _config_mapping
         - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
         - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
         - **Default Values**: Applied when parameters are not provided via arguments
         - **Attribute Setting**: All parameters become instance attributes accessible via dot notation
 
-    Example:
-        ```python
-        @auto_register_config
-        class MyAlgorithm(BaseYAMLSerializable):
-            def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default", **kwargs):
-                super().__init__(**kwargs)
-
-        # All instantiation methods work automatically:
-        algo1 = MyAlgorithm(0.01, 64)  # positional args
-        algo2 = MyAlgorithm(learning_rate=0.01, model_name="bert")  # keyword args
-        algo3 = MyAlgorithm(0.01, batch_size=128, model_name="gpt")  # mixed args
-
-        # Attributes are automatically set and can be serialized:
-        print(algo1.learning_rate)  # 0.01
-        print(algo1.batch_size)  # 64
-        print(algo1.model_name)  # "default" (from default value)
-
-        config = algo1.config
-        # DictConfig({'_target_': 'MyAlgorithm', 'learning_rate': 0.01, 'batch_size': 64, 'model_name': 'default'})
-        ```
-
     Note:
         - The decorator wraps the original __init__ method while preserving its signature for IDE support
-        - Parameters with *args or **kwargs signatures are ignored during registration
+        - Parameters with ``*args`` or ``**kwargs`` signatures are ignored during registration
         - The attributes are auto-registered, then the original __init__ method is called,
         - Type hints, method name, and other metadata are preserved using functools.wraps
         - This decorator is designed to work seamlessly with the YAML serialization system
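For reference, a short usage sketch of the decorator, adapted from the example this release removes from the docstring above (class and parameter names are illustrative):

```python
@auto_register_config
class MyAlgorithm(BaseYAMLSerializable):
    def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, **kwargs):
        super().__init__(**kwargs)

# Positional and keyword arguments both register as attributes.
algo = MyAlgorithm(0.01, batch_size=64)
print(algo.learning_rate)  # 0.01 -- set before the original __init__ runs
print(algo.config)
# DictConfig({'_target_': 'MyAlgorithm', 'learning_rate': 0.01, 'batch_size': 64})
```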
@@ -103,7 +90,10 @@ def auto_register_config(cls):
 
     # Auto-register parameters in _config_mapping
     if not "_config_mapping" in cls.__dict__:
-        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", {}))
+        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", bidict()))
+    if not isinstance(cls._config_mapping, bidict):
+        cls._config_mapping = bidict(cls._config_mapping)
+
     registered_parameters = tuple(cls._config_mapping.values())
 
     for param_name in list(sig.parameters.keys())[1:]:  # Skip 'self'
@@ -116,6 +106,7 @@ def auto_register_config(cls):
         ) and (param_name not in registered_parameters):
             cls._config_mapping[param_name] = param_name
 
+    @wraps(original_init)
     def __init__(self, *args, **kwargs):
         log.debug(f"set attributes for {self.__class__.__name__} in {cls.__name__}")
         # auto-register the attributes based on the signature
@@ -162,33 +153,10 @@ def auto_register_config(cls):
 
 class YAMLSerializationMixin:
     _config_key: Optional[str] = None
-    _config_mapping: Dict[str, str] = {}
+    _config_mapping: MutableBidict[str, str] = bidict()
     R"""
     `_config_mapping` is a dictionary mapping the attribute names of the class to the config option names. This is used to convert the class to a DictConfig.
 
-    For example, if an algorithm class is defined as follows:
-
-    ```python
-    class SomeModelFusionAlgorithm(BaseModelFusionAlgorithm):
-        hyper_parameter_1 = None
-        hyper_parameter_2 = None
-
-        _config_mapping = BaseModelFusionAlgorithm._config_mapping | {
-            "hyper_parameter_1": "hyper_param_1",
-            "hyper_parameter_2": "hyper_param_2",
-        }
-
-        def __init__(self, hyper_param_1: int, hyper_param_2: int):
-            self.hyper_parameter_1 = hyper_param_1
-            self.hyper_parameter_2 = hyper_param_2
-            super().__init__()
-    ```
-
-    The model pool will be converted to a DictConfig as follows:
-
-    ```python
-    algorithm = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-    ```
-
     >>> algorithm.config
     DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
 
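With `_config_mapping` now a bidict, subclass mappings must stay one-to-one in both directions. A hedged sketch of the subclassing pattern from the removed docstring example, updated for bidict (names are illustrative):

```python
from bidict import bidict

class SomeModelFusionAlgorithm(BaseModelFusionAlgorithm):
    # Extend the parent mapping; a duplicate config option name would raise
    # bidict.ValueDuplicationError, since the inverse view must stay unique.
    _config_mapping = bidict(
        {
            **BaseModelFusionAlgorithm._config_mapping,
            "hyper_parameter_1": "hyper_param_1",
            "hyper_parameter_2": "hyper_param_2",
        }
    )

    def __init__(self, hyper_param_1: int, hyper_param_2: int):
        self.hyper_parameter_1 = hyper_param_1
        self.hyper_parameter_2 = hyper_param_2
        super().__init__()
```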
@@ -207,17 +175,6 @@ class YAMLSerializationMixin:
         This property converts the model pool instance into a dictionary
         configuration, which can be used for serialization or other purposes.
 
-        Example:
-
-        ```python
-        model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
-        config = model.config
-        print(config)
-        # DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
-        ```
-
-        This is useful for serializing the object to a YAML file or for debugging.
-
         Returns:
             DictConfig: The configuration of the model pool.
         """
@@ -282,16 +239,6 @@ class YAMLSerializationMixin:
                 serialization. This is how the attribute will appear in YAML output.
             value: The value to assign to the attribute.
 
-        Example:
-            ```python
-            model = BaseYAMLSerializable()
-            model.set_option("learning_rate", "lr", 0.001)
-
-            # This sets model.learning_rate = 0.001
-            # and maps it to "lr" in the config output
-            config = model.config
-            # config will contain: {"lr": 0.001, ...}
-            ```
         """
         setattr(self, attr_name, value)
         self._config_mapping[attr_name] = param_name
fusion_bench/modelpool/causal_lm/causal_lm.py

@@ -1,5 +1,5 @@
 """
-Online documentation for this module: https://tanganke.github.io/fusion_bench/modelpool/causal_lm
+Online documentation for this module: https://tanganke.github.io/fusion_bench/modelpool/llm
 """
 
 import logging
@@ -26,6 +26,7 @@ from fusion_bench import (
     instantiate,
     parse_dtype,
 )
+from fusion_bench.models.hf_utils import create_default_model_card
 from fusion_bench.utils.lazy_state_dict import LazyStateDict
 
 log = logging.getLogger(__name__)
@@ -271,13 +272,16 @@ class CausalLMPool(BaseModelPool):
         save_tokenizer: bool = False,
         tokenizer_kwargs=None,
         tokenizer: Optional[PreTrainedTokenizer] = None,
+        algorithm_config: Optional[DictConfig] = None,
+        description: Optional[str] = None,
+        base_model_in_modelcard: bool = True,
         **kwargs,
     ):
         """Save a model to the specified path with optional tokenizer and Hub upload.
 
         This method provides comprehensive model saving capabilities including
-        optional tokenizer saving, dtype conversion, and Hugging Face Hub upload.
-        The model is saved in the standard Hugging Face format.
+        optional tokenizer saving, dtype conversion, model card creation, and
+        Hugging Face Hub upload. The model is saved in the standard Hugging Face format.
 
         Args:
             model: The PreTrainedModel instance to be saved.
@@ -295,15 +299,13 @@ class CausalLMPool(BaseModelPool):
                 when save_tokenizer is True.
             tokenizer: Optional pre-loaded tokenizer instance. If provided, this
                 tokenizer will be saved regardless of the save_tokenizer flag.
+            algorithm_config: Optional DictConfig containing algorithm configuration.
+                If provided, a model card will be created with algorithm details.
+            description: Optional description for the model card. If not provided
+                and algorithm_config is given, a default description will be generated.
             **kwargs: Additional keyword arguments passed to the model's
                 save_pretrained method.
 
-        Side Effects:
-            - Creates model files in the specified directory
-            - Optionally creates tokenizer files in the same directory
-            - May convert the model to a different dtype
-            - May upload files to Hugging Face Hub
-
         Example:
             ```python
             >>> pool = CausalLMPool(models=..., tokenizer=...)
@@ -313,7 +315,9 @@ class CausalLMPool(BaseModelPool):
             ...     "/path/to/save",
             ...     save_tokenizer=True,
             ...     model_dtype="float16",
-            ...     push_to_hub=True
+            ...     push_to_hub=True,
+            ...     algorithm_config=algorithm_config,
+            ...     description="Custom merged model"
             ... )
             ```
         """
@@ -337,6 +341,24 @@ class CausalLMPool(BaseModelPool):
             **kwargs,
         )
 
+        # Create and save model card if algorithm_config is provided
+        if algorithm_config is not None:
+            if description is None:
+                description = "Model created using FusionBench."
+            model_card_str = create_default_model_card(
+                base_model=(
+                    self.get_model_path("_pretrained_")
+                    if base_model_in_modelcard and self.has_pretrained
+                    else None
+                ),
+                models=[self.get_model_path(m) for m in self.model_names],
+                description=description,
+                algorithm_config=algorithm_config,
+                modelpool_config=self.config,
+            )
+            with open(os.path.join(path, "README.md"), "w") as f:
+                f.write(model_card_str)
+
 
 class CausalLMBackbonePool(CausalLMPool):
     """A specialized model pool that loads only the transformer backbone layers.
fusion_bench/models/hf_clip.py

@@ -195,5 +195,9 @@ class HFCLIPClassifier(nn.Module):
             pass
         elif isinstance(image_embeds, BaseModelOutputWithPooling):
             image_embeds = image_embeds[1]
+        elif isinstance(image_embeds, dict) and "pooler_output" in image_embeds:
+            image_embeds = image_embeds["pooler_output"]
+        else:
+            raise ValueError("Unsupported output type from vision model outputs")
         image_embeds = self.clip_model.visual_projection(image_embeds)
         return image_embeds
fusion_bench/models/hf_utils.py

@@ -143,7 +143,7 @@ def save_pretrained_with_remote_code(
 
 def create_default_model_card(
     models: list[str],
-    *,
+    base_model: Optional[str] = None,
     title: str = "Deep Model Fusion",
     tags: list[str] = ["fusion-bench", "merge"],
     description=None,
@@ -154,6 +154,7 @@ def create_default_model_card(
 
     template: Template = Template(load_model_card_template("default.md"))
     card = template.render(
+        base_model=base_model,
         models=models,
         library_name="transformers",
         title=title,
fusion_bench/models/model_card_templates/default.md

@@ -1,5 +1,8 @@
 ---
 base_model:
+{%- if base_model is not none %}
+- {{ base_model }}
+{%- endif %}
 {%- for model in models %}
 - {{ model }}
 {%- endfor %}
@@ -18,7 +21,11 @@ tags:
 This is a merged model created using [fusion-bench](https://github.com/tanganke/fusion_bench).
 
 The following models were included in the merge:
-{% for model in models %}
+
+{% if base_model is not none %}
+- base model: {{ base_model }}
+{%- endif %}
+{%- for model in models %}
 - {{ model }}
 {%- endfor %}
 
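The new conditional front matter can be sanity-checked in isolation. A minimal jinja2 sketch with the template inlined (not the packaged default.md; model names are placeholders):

```python
from jinja2 import Template

template = Template(
    "base_model:\n"
    "{%- if base_model is not none %}\n"
    "- {{ base_model }}\n"
    "{%- endif %}\n"
    "{%- for model in models %}\n"
    "- {{ model }}\n"
    "{%- endfor %}"
)
print(template.render(base_model="Qwen/Qwen2.5-1.5B", models=["a/math", "b/code"]))
# base_model:
# - Qwen/Qwen2.5-1.5B
# - a/math
# - b/code
```

When `base_model` is None, the `{%- if %}` block collapses entirely, so the rendered list is unchanged from the previous release.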
fusion_bench/models/wrappers/ensemble.py

@@ -1,10 +1,17 @@
-from typing import Any, Callable, Dict, List, Union, cast
+import logging
+from typing import Any, Callable, Dict, Generic, List, Union, cast
 
 import numpy as np
 import torch
+import torch.futures
 from omegaconf import ListConfig
 from torch import Tensor, nn
 
+from fusion_bench.utils.devices import to_device
+from fusion_bench.utils.type import TorchModelType
+
+log = logging.getLogger(__name__)
+
 
 def aggregate_tensors(
     outputs: List[Any], aggregate_fn: Callable
@@ -58,12 +65,16 @@ def aggregate_tensors(
     raise ValueError("Unsupported type for outputs")
 
 
-class EnsembleModule(nn.Module):
+class EnsembleModule(nn.Module, Generic[TorchModelType]):
     """
     Ensemble module that averages the outputs of multiple models.
     """
 
-    def __init__(self, models: List[nn.Module]):
+    def __init__(
+        self,
+        models: List[TorchModelType],
+        device_map: Dict[int, Union[int, str]] | None = None,
+    ):
         """
         Initializes the EnsembleModule with a list of models.
 
@@ -73,6 +84,16 @@ class EnsembleModule(nn.Module):
         super().__init__()
         # TODO: distribute models to devices
         self.model_list = nn.ModuleList(models)
+        self.device_map = device_map
+        if self.device_map is not None:
+            self._move_models_to_devices()
+
+    def _move_models_to_devices(self):
+        for model_idx, device_id in self.device_map.items():
+            log.info(f"Moving model {model_idx} to device {device_id}")
+            self.model_list[model_idx] = self.model_list[model_idx].to(
+                device_id, non_blocking=True
+            )
 
     def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
         """
@@ -86,6 +107,49 @@ class EnsembleModule(nn.Module):
         """
         return torch.stack(outputs).mean(dim=0)
 
+    def _parallel_forward_with_device_map(self, *args: Any, **kwargs: Any) -> List[Any]:
+        """
+        Performs parallel forward pass using device mapping with futures.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[Any]: List of outputs from all models, all moved to the same device.
+        """
+        futures = []
+
+        device_data_cache = {}
+        for i, model in enumerate(self.model_list):
+            device_id = self.device_map.get(i, "cpu")
+
+            if device_id not in device_data_cache:
+                # Move inputs to the same device as the model
+                device_args = to_device(
+                    args, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_kwargs = to_device(
+                    kwargs, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_data_cache[device_id] = (device_args, device_kwargs)
+            else:
+                device_args, device_kwargs = device_data_cache[device_id]
+
+            # Create a future for asynchronous execution
+            future = torch.jit.fork(model, *device_args, **device_kwargs)
+            futures.append(future)
+
+        # Wait for all futures to complete and collect results
+        outputs = [torch.jit.wait(future) for future in futures]
+
+        # Move all outputs to the same device (use the device of the first model or cpu as fallback)
+        target_device = self.device_map.get(0, "cpu") if self.device_map else "cpu"
+        outputs = [
+            to_device(output, target_device, non_blocking=True) for output in outputs
+        ]
+        return outputs
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Performs a forward pass by averaging the outputs of the models.
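The fork/wait pattern used above works in eager mode as well as TorchScript: `torch.jit.fork` schedules the call on an interop thread and returns a future, and `torch.jit.wait` blocks until the result is ready. A tiny self-contained sketch of the mechanism:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
x = torch.randn(1, 4)

# fork returns immediately with a Future; wait collects the result.
future = torch.jit.fork(model, x)
output = torch.jit.wait(future)
print(output.shape)  # torch.Size([1, 2])
```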
@@ -97,20 +161,25 @@ class EnsembleModule(nn.Module):
         Returns:
             Aggregated output from the ensemble of models.
         """
-        outputs = [model(*args, **kwargs) for model in self.model_list]
+        if self.device_map is None:
+            outputs = [model(*args, **kwargs) for model in self.model_list]
+        else:
+            # Parallel execution with device mapping
+            outputs = self._parallel_forward_with_device_map(*args, **kwargs)
         return aggregate_tensors(outputs, self._aggregate_tensors)
 
 
-class WeightedEnsembleModule(nn.Module):
+class WeightedEnsembleModule(nn.Module, Generic[TorchModelType]):
     """
     Ensemble module that computes a weighted average of the outputs from multiple models.
     """
 
     def __init__(
         self,
-        models: List[nn.Module],
+        models: List[TorchModelType],
         weights: List[float] | Tensor | np.ndarray,
         normalize: bool = True,
+        device_map: Dict[int, Union[int, str]] | None = None,
     ):
         """
         Initializes the WeightedEnsembleModule with models and their corresponding weights.
@@ -119,9 +188,12 @@ class WeightedEnsembleModule(nn.Module):
             models (List[nn.Module]): List of models to ensemble.
             weights (List[float] | Tensor | np.ndarray): Weights for each model.
             normalize (bool, optional): If True, normalizes the weights. Defaults to True.
+            device_map (Dict[int, Union[int, str]] | None, optional): Device mapping for parallel execution. Defaults to None.
         """
         super().__init__()
         self.model_list = nn.ModuleList(models)
+        self.device_map = device_map
+
         if isinstance(weights, (list, tuple, ListConfig)):
             weights = torch.tensor(weights)
         elif isinstance(weights, Tensor):
@@ -139,6 +211,17 @@ class WeightedEnsembleModule(nn.Module):
             weights = weights / weights.sum()
         self.register_buffer("weights", weights)
 
+        if self.device_map is not None:
+            self._move_models_to_devices()
+
+    def _move_models_to_devices(self):
+        """Move models to their assigned devices according to device_map."""
+        for model_idx, device_id in self.device_map.items():
+            log.info(f"Moving model {model_idx} to device {device_id}")
+            self.model_list[model_idx] = self.model_list[model_idx].to(
+                device_id, non_blocking=True
+            )
+
     def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
         """
         Aggregates a list of tensors using the provided weights.
@@ -152,6 +235,48 @@ class WeightedEnsembleModule(nn.Module):
         weights = cast(Tensor, self.weights).view(-1, *([1] * outputs[0].dim()))
         return (torch.stack(outputs) * weights).sum(dim=0)
 
+    def _parallel_forward_with_device_map(self, *args: Any, **kwargs: Any) -> List[Any]:
+        """
+        Performs parallel forward pass using device mapping with futures.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[Any]: List of outputs from all models, all moved to the same device.
+        """
+        futures = []
+
+        device_data_cache = {}
+        for i, model in enumerate(self.model_list):
+            device_id = self.device_map.get(i, "cpu")
+
+            if device_id not in device_data_cache:
+                # Move inputs to the same device as the model
+                device_args = to_device(
+                    args, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_kwargs = to_device(
+                    kwargs, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_data_cache[device_id] = (device_args, device_kwargs)
+            else:
+                device_args, device_kwargs = device_data_cache[device_id]
+
+            # Create a future for asynchronous execution
+            future = torch.jit.fork(model, *device_args, **device_kwargs)
+            futures.append(future)
+
+        # Wait for all futures to complete and collect results
+        outputs = [torch.jit.wait(future) for future in futures]
+
+        # Move all outputs to the same device (use the device of the first model or cpu as fallback)
+        target_device = self.device_map.get(0, "cpu") if self.device_map else "cpu"
+        outputs = [to_device(output, target_device) for output in outputs]
+
+        return outputs
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Performs a forward pass by computing the weighted average of the models' outputs.
@@ -163,7 +288,11 @@ class WeightedEnsembleModule(nn.Module):
         Returns:
             Weighted aggregated output from the ensemble of models.
         """
-        outputs = [model(*args, **kwargs) for model in self.model_list]
+        if self.device_map is None:
+            outputs = [model(*args, **kwargs) for model in self.model_list]
+        else:
+            # Parallel execution with device mapping
+            outputs = self._parallel_forward_with_device_map(*args, **kwargs)
         return aggregate_tensors(outputs, self._aggregate_tensors)
 
 
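A hedged usage sketch of the new `device_map` argument, with toy models on CPU (on a multi-GPU machine the values would be e.g. "cuda:0", "cuda:1"; the import path assumes this wheel's module layout):

```python
import torch
from torch import nn

from fusion_bench.models.wrappers.ensemble import (
    EnsembleModule,
    WeightedEnsembleModule,
)

models = [nn.Linear(8, 2) for _ in range(3)]
x = torch.randn(4, 8)

# Without device_map the forward loop stays sequential, as before.
ensemble = EnsembleModule(models)
print(ensemble(x).shape)  # torch.Size([4, 2])

# With device_map, each model runs via torch.jit.fork on its assigned device;
# indices missing from the map fall back to "cpu".
weighted = WeightedEnsembleModule(
    models, weights=[0.5, 0.3, 0.2], device_map={0: "cpu", 1: "cpu", 2: "cpu"}
)
print(weighted(x).shape)  # torch.Size([4, 2])
```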
fusion_bench/scripts/cli.py

@@ -20,8 +20,8 @@ log = logging.getLogger(__name__)
 
 
 def _get_default_config_path():
-    for config_dir in ["fusion_bench_config", "config"]:
-        for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+    for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+        for config_dir in ["config", "fusion_bench_config"]:
             config_path = os.path.join(config_path_root, config_dir)
             if os.path.exists(config_path) and os.path.isdir(config_path):
                 return os.path.abspath(config_path)
fusion_bench/taskpool/clip_vision/taskpool.py

@@ -27,7 +27,7 @@ from tqdm.autonotebook import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPVisionTransformer
 
-from fusion_bench import RuntimeConstants
+from fusion_bench import RuntimeConstants, auto_register_config
 from fusion_bench.dataset import CLIPDataset
 from fusion_bench.mixins import HydraConfigMixin, LightningFabricMixin
 from fusion_bench.models.hf_clip import HFCLIPClassifier
@@ -86,6 +86,7 @@ class LayerWiseFeatureSaver:
         torch.save(features, self.save_path)
 
 
+@auto_register_config
 class CLIPVisionModelTaskPool(
     HydraConfigMixin,
     LightningFabricMixin,
@@ -134,11 +135,13 @@ class CLIPVisionModelTaskPool(
         layer_wise_feature_first_token_only: bool = True,
         layer_wise_feature_max_num: Optional[int] = None,
         fast_dev_run: Optional[bool] = None,
+        move_to_device: bool = True,
         **kwargs,
     ):
         """
         Initialize the CLIPVisionModelTaskPool.
         """
+        super().__init__(**kwargs)
         self._test_datasets = test_datasets
         self._processor = processor
         self._data_processor = data_processor
@@ -159,7 +162,6 @@ class CLIPVisionModelTaskPool(
             self.fast_dev_run = RuntimeConstants().debug
         else:
             self.fast_dev_run = fast_dev_run
-        super().__init__(**kwargs)
 
     def setup(self):
         """
@@ -220,7 +222,9 @@ class CLIPVisionModelTaskPool(
             for name, dataset in self.test_datasets.items()
         }
         self.test_dataloaders = {
-            name: self.fabric.setup_dataloaders(dataloader)
+            name: self.fabric.setup_dataloaders(
+                dataloader, move_to_device=self.move_to_device
+            )
             for name, dataloader in self.test_dataloaders.items()
         }
@@ -273,6 +277,8 @@ class CLIPVisionModelTaskPool(
                 task_name=task_name,
             )
             logits: Tensor = outputs["logits"]
+            if logits.device != targets.device:
+                targets = targets.to(logits.device)
 
             loss = F.cross_entropy(logits, targets)
             loss_metric.update(loss.detach().cpu())
@@ -321,7 +327,8 @@ class CLIPVisionModelTaskPool(
                 self.clip_model,
                 processor=self.processor,
             )
-        classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
+        if self.move_to_device:
+            classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
         # collect basic model information
         training_params, all_params = count_parameters(model)
         report["model_info"] = {