fusion-bench 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/dataset/gpt2_glue.py +1 -1
  7. fusion_bench/method/__init__.py +12 -2
  8. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  9. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  10. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  11. fusion_bench/method/ensemble.py +17 -2
  12. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  13. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  14. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  15. fusion_bench/method/linear/__init__.py +6 -2
  16. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  17. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  18. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  19. fusion_bench/method/model_stock/__init__.py +1 -0
  20. fusion_bench/method/model_stock/model_stock.py +309 -0
  21. fusion_bench/method/regmean/clip_regmean.py +3 -6
  22. fusion_bench/method/regmean/regmean.py +27 -56
  23. fusion_bench/method/regmean/utils.py +56 -0
  24. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  25. fusion_bench/method/simple_average.py +2 -2
  26. fusion_bench/method/slerp/__init__.py +1 -1
  27. fusion_bench/method/slerp/slerp.py +110 -14
  28. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  29. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  30. fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
  31. fusion_bench/method/wudi/__init__.py +1 -0
  32. fusion_bench/method/wudi/wudi.py +105 -0
  33. fusion_bench/mixins/clip_classification.py +26 -6
  34. fusion_bench/mixins/lightning_fabric.py +4 -0
  35. fusion_bench/mixins/serialization.py +40 -83
  36. fusion_bench/modelpool/base_pool.py +1 -1
  37. fusion_bench/modelpool/causal_lm/causal_lm.py +285 -44
  38. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  39. fusion_bench/models/hf_clip.py +4 -0
  40. fusion_bench/models/hf_utils.py +10 -4
  41. fusion_bench/models/linearized/vision_model.py +6 -6
  42. fusion_bench/models/model_card_templates/default.md +8 -1
  43. fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
  44. fusion_bench/models/we_moe.py +8 -8
  45. fusion_bench/models/wrappers/ensemble.py +136 -7
  46. fusion_bench/scripts/cli.py +2 -2
  47. fusion_bench/taskpool/base_pool.py +99 -17
  48. fusion_bench/taskpool/clip_vision/taskpool.py +12 -5
  49. fusion_bench/taskpool/dummy.py +101 -13
  50. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  51. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  52. fusion_bench/utils/__init__.py +1 -0
  53. fusion_bench/utils/data.py +6 -4
  54. fusion_bench/utils/devices.py +36 -11
  55. fusion_bench/utils/dtype.py +3 -2
  56. fusion_bench/utils/lazy_state_dict.py +85 -19
  57. fusion_bench/utils/packages.py +3 -3
  58. fusion_bench/utils/parameters.py +0 -2
  59. fusion_bench/utils/rich_utils.py +7 -3
  60. fusion_bench/utils/timer.py +92 -10
  61. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
  62. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +77 -64
  63. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  64. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  65. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  66. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  67. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  68. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  69. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  70. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  71. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  72. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  73. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  74. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  75. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
  76. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
  77. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
  78. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
fusion_bench/models/linearized/vision_model.py
@@ -45,21 +45,21 @@ def linearize_lora_model_(model):
 
 
 def load_fft_vision_model_hf(
-    model_name: str, return_vison_model=True
+    model_name: str, return_vision_model=True
 ) -> Union[CLIPVisionTransformer, CLIPVisionModel]:
     """
     Load a CLIP vision model from Hugging Face.
 
     Args:
         model_name (str): The name of the CLIP vision model to load from Hugging Face.
-        return_vison_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.
+        return_vision_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.
 
     Returns:
         Union[CLIPVisionTransformer, CLIPVisionModel]: The vision model.
     """
     model = CLIPVisionModel.from_pretrained(model_name)
 
-    if return_vison_model:
+    if return_vision_model:
         return CLIPVisionModel.from_pretrained(model_name).vision_model
     else:
         return model
@@ -69,7 +69,7 @@ def load_lora_vision_model_hf(
     base_model_name: str,
     peft_name: str,
     merge_and_unload: bool = False,
-    return_vison_model=True,
+    return_vision_model=True,
 ) -> PeftModel:
     """
     Load a LoRA (Low-Rank Adaptation) vision model from Hugging Face.
@@ -80,7 +80,7 @@ def load_lora_vision_model_hf(
         base_model_name (str): The name of the base vision model to load from Hugging Face.
         peft_name (str): The name of the LoRA adaptation to apply to the base model.
         merge_and_unload (bool, optional): If True, the LoRA adaptation is merged into the base model and the LoRA layers are removed. Defaults to False.
-        return_vison_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.
+        return_vision_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.
 
     Returns:
         PeftModel: The adapted vision model, optionally merged and unloaded.
@@ -97,7 +97,7 @@ def load_lora_vision_model_hf(
         vision_model = peft_model
 
     # Return the vision model
-    if return_vison_model:
+    if return_vision_model:
         return vision_model
     else:
         model.vision_model = vision_model
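
The only behavioral change in this file is the spelling fix `return_vison_model` → `return_vision_model`, so any call site that passed the old keyword will now raise a `TypeError`. A minimal usage sketch of the corrected keyword; the checkpoint name is illustrative:

```python
from fusion_bench.models.linearized.vision_model import load_fft_vision_model_hf

# Default (True): returns only the inner CLIPVisionTransformer.
vision_model = load_fft_vision_model_hf(
    "openai/clip-vit-base-patch32", return_vision_model=True
)

# The old misspelled keyword no longer exists:
# load_fft_vision_model_hf(..., return_vison_model=True)  # TypeError
```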
fusion_bench/models/model_card_templates/default.md
@@ -1,5 +1,8 @@
 ---
 base_model:
+{%- if base_model is not none %}
+- {{ base_model }}
+{%- endif %}
 {%- for model in models %}
 - {{ model }}
 {%- endfor %}
@@ -18,7 +21,11 @@ tags:
 This is a merged model created using [fusion-bench](https://github.com/tanganke/fusion_bench).
 
 The following models were included in the merge:
-{% for model in models %}
+
+{% if base_model is not none %}
+- base model: {{ base_model }}
+{%- endif %}
+{%- for model in models %}
 - {{ model }}
 {%- endfor %}
 
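This template change makes the base model an optional entry in the generated model card. A sketch of how the updated front-matter block renders, using jinja2 directly; the template string is abbreviated from the hunk above and the model names are illustrative:

```python
from jinja2 import Template

template = Template(
    "base_model:\n"
    "{%- if base_model is not none %}\n"
    "- {{ base_model }}\n"
    "{%- endif %}\n"
    "{%- for model in models %}\n"
    "- {{ model }}\n"
    "{%- endfor %}"
)
print(template.render(base_model="org/base-7b", models=["org/expert-a", "org/expert-b"]))
# base_model:
# - org/base-7b
# - org/expert-a
# - org/expert-b
```

When `base_model` is None, the `if` block renders nothing and the card lists only the merged models, matching the old behavior.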
fusion_bench/models/modeling_smile_mistral/__init__.py
@@ -1,6 +1,7 @@
 from . import register
 from .configuration_smile_mistral import SmileMistralConfig
 from .modeling_smile_mistral import (
+    SmileMistralDecoderLayer,
     SmileMistralForCausalLM,
     SmileMistralModel,
 )
fusion_bench/models/we_moe.py
@@ -1,6 +1,6 @@
 import functools
 import logging
-from typing import List
+from typing import Generic, List
 
 import torch
 import torch.func
@@ -9,7 +9,7 @@ from torch.func import functional_call
 from torch.nn import functional as F
 
 from fusion_bench.models.utils import del_attr, get_attr, set_attr
-from fusion_bench.utils.type import StateDictType
+from fusion_bench.utils.type import StateDictType, TorchModelType
 
 log = logging.getLogger(__name__)
 
@@ -76,15 +76,15 @@ def construct_weight_ensembling_gate(
     return gate
 
 
-class WeightEnsemblingMoE(nn.Module):
+class WeightEnsemblingMoE(nn.Module, Generic[TorchModelType]):
     # variable to store the merged state dict temporarily
     _merged_state_dict: StateDictType = None
 
     def __init__(
         self,
         hidden_size: int,
-        base_model: nn.Module,
-        expert_models: List[nn.Module],
+        base_model: TorchModelType,
+        expert_models: List[TorchModelType],
         init_lambda: float = 0.2,
         batch_first: bool = False,
         router_hidden_layers: int = 2,
@@ -101,8 +101,8 @@ class WeightEnsemblingMoE(nn.Module):
         Args:
 
             hidden_size (int): The size of the hidden layer in the models.
-            base_model (nn.Module): The base model that will be used as a reference for the expert models.
-            expert_models (List[nn.Module]): A list of expert models that will be combined.
+            base_model (TorchModelType): The base model that will be used as a reference for the expert models.
+            expert_models (List[TorchModelType]): A list of expert models that will be combined.
             init_lambda (float, optional): The initial lambda value for the weight ensembling gate. Defaults to 0.2.
             batch_first (bool, optional): If True, the input tensors are expected to have the batch size as the first dimension. Defaults to False.
             router_hidden_layers (int, optional): The number of hidden layers in the router. Defaults to 2.
@@ -145,7 +145,7 @@ class WeightEnsemblingMoE(nn.Module):
             self._merged_state_dict,
         )
 
-    def merge_weights(self, expert_weights):
+    def merge_weights(self, expert_weights) -> StateDictType:
         state_dict = self.base_model.state_dict(keep_vars=True)
         for weight, task_vector in zip(expert_weights, self.task_vectors):
            for name, param in task_vector.named_parameters():
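
These edits are type-annotation changes only: `WeightEnsemblingMoE` becomes generic over the wrapped model class, and `merge_weights` gains a return annotation. A minimal sketch of what the `Generic[TorchModelType]` base buys, assuming `TorchModelType` is a `TypeVar` bound to `nn.Module` (its use here implies something of that shape):

```python
from typing import Generic, TypeVar

from torch import nn

# Assumed definition; fusion_bench.utils.type presumably declares it similarly.
TorchModelType = TypeVar("TorchModelType", bound=nn.Module)


class Wrapper(nn.Module, Generic[TorchModelType]):
    """Toy stand-in for WeightEnsemblingMoE's new generic base."""

    def __init__(self, base_model: TorchModelType):
        super().__init__()
        self.base_model = base_model


# A type checker now infers w.base_model as nn.Linear rather than bare nn.Module.
w: Wrapper[nn.Linear] = Wrapper(nn.Linear(8, 8))
```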
fusion_bench/models/wrappers/ensemble.py
@@ -1,10 +1,17 @@
-from typing import Any, Callable, Dict, List, Union, cast
+import logging
+from typing import Any, Callable, Dict, Generic, List, Union, cast
 
 import numpy as np
 import torch
+import torch.futures
 from omegaconf import ListConfig
 from torch import Tensor, nn
 
+from fusion_bench.utils.devices import to_device
+from fusion_bench.utils.type import TorchModelType
+
+log = logging.getLogger(__name__)
+
 
 def aggregate_tensors(
     outputs: List[Any], aggregate_fn: Callable
@@ -58,12 +65,16 @@ def aggregate_tensors(
         raise ValueError("Unsupported type for outputs")
 
 
-class EnsembleModule(nn.Module):
+class EnsembleModule(nn.Module, Generic[TorchModelType]):
     """
     Ensemble module that averages the outputs of multiple models.
     """
 
-    def __init__(self, models: List[nn.Module]):
+    def __init__(
+        self,
+        models: List[TorchModelType],
+        device_map: Dict[int, Union[int, str]] | None = None,
+    ):
         """
         Initializes the EnsembleModule with a list of models.
 
@@ -73,6 +84,16 @@ class EnsembleModule(nn.Module):
         super().__init__()
         # TODO: distribute models to devices
        self.model_list = nn.ModuleList(models)
+        self.device_map = device_map
+        if self.device_map is not None:
+            self._move_models_to_devices()
+
+    def _move_models_to_devices(self):
+        for model_idx, device_id in self.device_map.items():
+            log.info(f"Moving model {model_idx} to device {device_id}")
+            self.model_list[model_idx] = self.model_list[model_idx].to(
+                device_id, non_blocking=True
+            )
 
     def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
         """
@@ -86,6 +107,49 @@ class EnsembleModule(nn.Module):
         """
         return torch.stack(outputs).mean(dim=0)
 
+    def _parallel_forward_with_device_map(self, *args: Any, **kwargs: Any) -> List[Any]:
+        """
+        Performs parallel forward pass using device mapping with futures.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[Any]: List of outputs from all models, all moved to the same device.
+        """
+        futures = []
+
+        device_data_cache = {}
+        for i, model in enumerate(self.model_list):
+            device_id = self.device_map.get(i, "cpu")
+
+            if device_id not in device_data_cache:
+                # Move inputs to the same device as the model
+                device_args = to_device(
+                    args, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_kwargs = to_device(
+                    kwargs, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_data_cache[device_id] = (device_args, device_kwargs)
+            else:
+                device_args, device_kwargs = device_data_cache[device_id]
+
+            # Create a future for asynchronous execution
+            future = torch.jit.fork(model, *device_args, **device_kwargs)
+            futures.append(future)
+
+        # Wait for all futures to complete and collect results
+        outputs = [torch.jit.wait(future) for future in futures]
+
+        # Move all outputs to the same device (use the device of the first model or cpu as fallback)
+        target_device = self.device_map.get(0, "cpu") if self.device_map else "cpu"
+        outputs = [
+            to_device(output, target_device, non_blocking=True) for output in outputs
+        ]
+        return outputs
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Performs a forward pass by averaging the outputs of the models.
@@ -97,20 +161,25 @@ class EnsembleModule(nn.Module):
         Returns:
             Aggregated output from the ensemble of models.
         """
-        outputs = [model(*args, **kwargs) for model in self.model_list]
+        if self.device_map is None:
+            outputs = [model(*args, **kwargs) for model in self.model_list]
+        else:
+            # Parallel execution with device mapping
+            outputs = self._parallel_forward_with_device_map(*args, **kwargs)
         return aggregate_tensors(outputs, self._aggregate_tensors)
 
 
-class WeightedEnsembleModule(nn.Module):
+class WeightedEnsembleModule(nn.Module, Generic[TorchModelType]):
     """
     Ensemble module that computes a weighted average of the outputs from multiple models.
     """
 
     def __init__(
         self,
-        models: List[nn.Module],
+        models: List[TorchModelType],
         weights: List[float] | Tensor | np.ndarray,
         normalize: bool = True,
+        device_map: Dict[int, Union[int, str]] | None = None,
     ):
         """
         Initializes the WeightedEnsembleModule with models and their corresponding weights.
@@ -119,9 +188,12 @@ class WeightedEnsembleModule(nn.Module):
             models (List[nn.Module]): List of models to ensemble.
             weights (List[float] | Tensor | np.ndarray): Weights for each model.
             normalize (bool, optional): If True, normalizes the weights. Defaults to True.
+            device_map (Dict[int, Union[int, str]] | None, optional): Device mapping for parallel execution. Defaults to None.
         """
         super().__init__()
         self.model_list = nn.ModuleList(models)
+        self.device_map = device_map
+
         if isinstance(weights, (list, tuple, ListConfig)):
             weights = torch.tensor(weights)
         elif isinstance(weights, Tensor):
@@ -139,6 +211,17 @@ class WeightedEnsembleModule(nn.Module):
             weights = weights / weights.sum()
         self.register_buffer("weights", weights)
 
+        if self.device_map is not None:
+            self._move_models_to_devices()
+
+    def _move_models_to_devices(self):
+        """Move models to their assigned devices according to device_map."""
+        for model_idx, device_id in self.device_map.items():
+            log.info(f"Moving model {model_idx} to device {device_id}")
+            self.model_list[model_idx] = self.model_list[model_idx].to(
+                device_id, non_blocking=True
+            )
+
     def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
         """
         Aggregates a list of tensors using the provided weights.
@@ -152,6 +235,48 @@ class WeightedEnsembleModule(nn.Module):
         weights = cast(Tensor, self.weights).view(-1, *([1] * outputs[0].dim()))
         return (torch.stack(outputs) * weights).sum(dim=0)
 
+    def _parallel_forward_with_device_map(self, *args: Any, **kwargs: Any) -> List[Any]:
+        """
+        Performs parallel forward pass using device mapping with futures.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[Any]: List of outputs from all models, all moved to the same device.
+        """
+        futures = []
+
+        device_data_cache = {}
+        for i, model in enumerate(self.model_list):
+            device_id = self.device_map.get(i, "cpu")
+
+            if device_id not in device_data_cache:
+                # Move inputs to the same device as the model
+                device_args = to_device(
+                    args, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_kwargs = to_device(
+                    kwargs, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_data_cache[device_id] = (device_args, device_kwargs)
+            else:
+                device_args, device_kwargs = device_data_cache[device_id]
+
+            # Create a future for asynchronous execution
+            future = torch.jit.fork(model, *device_args, **device_kwargs)
+            futures.append(future)
+
+        # Wait for all futures to complete and collect results
+        outputs = [torch.jit.wait(future) for future in futures]
+
+        # Move all outputs to the same device (use the device of the first model or cpu as fallback)
+        target_device = self.device_map.get(0, "cpu") if self.device_map else "cpu"
+        outputs = [to_device(output, target_device) for output in outputs]
+
+        return outputs
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Performs a forward pass by computing the weighted average of the models' outputs.
@@ -163,7 +288,11 @@ class WeightedEnsembleModule(nn.Module):
         Returns:
             Weighted aggregated output from the ensemble of models.
         """
-        outputs = [model(*args, **kwargs) for model in self.model_list]
+        if self.device_map is None:
+            outputs = [model(*args, **kwargs) for model in self.model_list]
+        else:
+            # Parallel execution with device mapping
+            outputs = self._parallel_forward_with_device_map(*args, **kwargs)
         return aggregate_tensors(outputs, self._aggregate_tensors)
 
 
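The new `device_map` argument assigns each ensemble member to its own device and dispatches the member forwards concurrently via `torch.jit.fork`, caching the moved inputs per device. A usage sketch, assuming the classes are importable from the file path above and that two CUDA devices are available; device ids are illustrative:

```python
import torch
from torch import nn

from fusion_bench.models.wrappers.ensemble import EnsembleModule

# Two small models, one per GPU.
models = [nn.Linear(16, 4), nn.Linear(16, 4)]
ensemble = EnsembleModule(models, device_map={0: "cuda:0", 1: "cuda:1"})

x = torch.randn(8, 16)
# Inputs are copied to each model's device, the two forwards run
# concurrently via torch.jit.fork, and the averaged result lands on
# model 0's device ("cuda:0").
y = ensemble(x)
```

Without a `device_map`, both wrappers fall back to the old sequential loop over `self.model_list`.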
fusion_bench/scripts/cli.py
@@ -20,8 +20,8 @@ log = logging.getLogger(__name__)
 
 
 def _get_default_config_path():
-    for config_dir in ["fusion_bench_config", "config"]:
-        for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+    for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+        for config_dir in ["config", "fusion_bench_config"]:
             config_path = os.path.join(config_path_root, config_dir)
             if os.path.exists(config_path) and os.path.isdir(config_path):
                 return os.path.abspath(config_path)
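
Swapping the loop nesting changes the search priority: the working directory now wins over the package root, and within each root a plain `config/` directory is preferred over `fusion_bench_config/`. A sketch of the resulting candidate order; `PROJECT_ROOT_PATH` is a stand-in for the package constant used above:

```python
import os

PROJECT_ROOT_PATH = "/path/to/fusion_bench"  # illustrative stand-in

# Candidate order after the swap; the first existing directory wins:
#   <cwd>/config, <cwd>/fusion_bench_config,
#   <root>/config, <root>/fusion_bench_config
candidates = [
    os.path.join(root, config_dir)
    for root in (os.getcwd(), PROJECT_ROOT_PATH)
    for config_dir in ("config", "fusion_bench_config")
]
```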
fusion_bench/taskpool/base_pool.py
@@ -5,33 +5,115 @@ from fusion_bench.mixins import BaseYAMLSerializable
 
 
 class BaseTaskPool(BaseYAMLSerializable):
+    """Abstract base class for task pools in the FusionBench framework.
+
+    A task pool represents a collection of evaluation tasks that can be used to
+    assess model performance across multiple benchmarks or datasets. This base
+    class defines the common interface that all task pool implementations must
+    follow, ensuring consistency across different task types and evaluation
+    scenarios.
+
+    Task pools are designed to be configurable through YAML files and can be
+    used in various model fusion and evaluation workflows. They provide a
+    standardized way to evaluate models on multiple tasks and aggregate results.
+
+    The class inherits from BaseYAMLSerializable to support configuration
+    management and serialization capabilities.
+
+    Attributes:
+        _program: Optional program reference for execution context.
+        _config_key: Configuration key used for YAML configuration ("taskpool").
+
+    Abstract Methods:
+        evaluate: Must be implemented by subclasses to define task-specific
+            evaluation logic.
+
+    Example:
+        Implementing a custom task pool:
+
+        ```python
+        class MyTaskPool(BaseTaskPool):
+
+            def evaluate(self, model, **kwargs):
+                results = {}
+                for task_name in self.tasks:
+                    # Implement task-specific evaluation
+                    results[task_name] = self._evaluate_task(model, task_name)
+                return results
+        ```
+    """
+
     _program = None
     _config_key = "taskpool"
 
     @abstractmethod
     def evaluate(self, model: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]:
-        """
-        Evaluate the model on all tasks in the task pool, and return a report.
+        """Evaluate a model on all tasks in the task pool and return aggregated results.
 
-        Take image classification as an example, the report will look like:
+        This abstract method defines the core evaluation interface that all task pool
+        implementations must provide. It should evaluate the given model on all tasks
+        managed by the pool and return a structured report of the results.
 
-        ```python
-        {
-            "mnist": {
-                "accuracy": 0.8,
-                "loss": 0.2,
-            },
-            <task_name>: {
-                <metric_name>: <metric_value>,
-                ...
-            },
-        }
-        ```
+        The evaluation process typically involves:
+        1. Iterating through all tasks in the pool
+        2. Running model inference on each task's dataset
+        3. Computing task-specific metrics
+        4. Aggregating results into a standardized report format
 
         Args:
-            model: The model to evaluate.
+            model: The model to evaluate. Can be any model type (PyTorch model,
+                Hugging Face model, etc.) that is compatible with the specific
+                task pool implementation.
+            *args: Additional positional arguments that may be needed for
+                task-specific evaluation procedures.
+            **kwargs: Additional keyword arguments for evaluation configuration,
+                such as batch_size, device, evaluation metrics, etc.
 
         Returns:
-            report (dict): A dictionary containing the results of the evaluation for each task.
+            Dict[str, Any]: A dictionary containing evaluation results for each task.
+                The structure follows the pattern:
+
+                ```python
+                {
+                    "task_name_1": {
+                        "metric_1": value,
+                        "metric_2": value,
+                        ...
+                    },
+                    "task_name_2": {
+                        "metric_1": value,
+                        "metric_2": value,
+                        ...
+                    },
+                    ...
+                }
+                ```
+
+        Example:
+            For an image classification task pool:
+
+            ```python
+            results = task_pool.evaluate(model)
+            # Returns:
+            # {
+            #     "mnist": {
+            #         "accuracy": 0.95,
+            #         "loss": 0.15,
+            #     },
+            #     "cifar10": {
+            #         "accuracy": 0.87,
+            #         "loss": 0.42,
+            #     }
+            # }
+            ```
+
+        Raises:
+            NotImplementedError: This method must be implemented by subclasses.
+
+        Note:
+            Implementations should ensure that the returned dictionary structure
+            is consistent and that metric names are standardized across similar
+            task types to enable meaningful comparison and aggregation.
         """
         pass
fusion_bench/taskpool/clip_vision/taskpool.py
@@ -27,7 +27,7 @@ from tqdm.autonotebook import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPVisionTransformer
 
-from fusion_bench import RuntimeConstants
+from fusion_bench import RuntimeConstants, auto_register_config
 from fusion_bench.dataset import CLIPDataset
 from fusion_bench.mixins import HydraConfigMixin, LightningFabricMixin
 from fusion_bench.models.hf_clip import HFCLIPClassifier
@@ -86,6 +86,7 @@ class LayerWiseFeatureSaver:
         torch.save(features, self.save_path)
 
 
+@auto_register_config
 class CLIPVisionModelTaskPool(
     HydraConfigMixin,
     LightningFabricMixin,
@@ -134,11 +135,13 @@ class CLIPVisionModelTaskPool(
         layer_wise_feature_first_token_only: bool = True,
         layer_wise_feature_max_num: Optional[int] = None,
         fast_dev_run: Optional[bool] = None,
+        move_to_device: bool = True,
         **kwargs,
     ):
         """
         Initialize the CLIPVisionModelTaskPool.
         """
+        super().__init__(**kwargs)
         self._test_datasets = test_datasets
         self._processor = processor
         self._data_processor = data_processor
@@ -159,7 +162,6 @@ class CLIPVisionModelTaskPool(
             self.fast_dev_run = RuntimeConstants().debug
         else:
             self.fast_dev_run = fast_dev_run
-        super().__init__(**kwargs)
 
     def setup(self):
         """
@@ -220,7 +222,9 @@ class CLIPVisionModelTaskPool(
             for name, dataset in self.test_datasets.items()
         }
         self.test_dataloaders = {
-            name: self.fabric.setup_dataloaders(dataloader)
+            name: self.fabric.setup_dataloaders(
+                dataloader, move_to_device=self.move_to_device
+            )
             for name, dataloader in self.test_dataloaders.items()
         }
 
@@ -273,6 +277,8 @@ class CLIPVisionModelTaskPool(
                 task_name=task_name,
             )
             logits: Tensor = outputs["logits"]
+            if logits.device != targets.device:
+                targets = targets.to(logits.device)
 
             loss = F.cross_entropy(logits, targets)
             loss_metric.update(loss.detach().cpu())
@@ -309,7 +315,7 @@ class CLIPVisionModelTaskPool(
         self.setup()
 
         report = {}
-        # CLIPVisionModel works the same with CLIPVisonTransformer, so we can use it directly
+        # CLIPVisionModel works the same with CLIPVisionTransformer, so we can use it directly
         if hasattr(model, "is_surgery_model") and model.is_surgery_model:
             log.info("running evaluation on a surgery model.")
             model: "SurgeryModelWrapper" = model
@@ -321,7 +327,8 @@ class CLIPVisionModelTaskPool(
                 self.clip_model,
                 processor=self.processor,
            )
-            classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
+            if self.move_to_device:
+                classifier = cast(HFCLIPClassifier, self.fabric.to_device(classifier))
             # collect basic model information
             training_params, all_params = count_parameters(model)
             report["model_info"] = {
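
The new `move_to_device` flag lets the evaluation loop leave batches wherever the dataloader produced them, which is what a device-mapped model (such as the ensemble wrappers above) wants; the added `targets.to(logits.device)` guard then keeps the loss computation on one device. A self-contained sketch of that guard:

```python
import torch
import torch.nn.functional as F

# The device-alignment fix in miniature: with move_to_device=False the
# batch can stay on the CPU while the model produces logits on another
# device, so the targets are moved to the logits' device before the loss.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
logits = torch.randn(4, 10, device=device)
targets = torch.randint(0, 10, (4,))  # still on the CPU
if logits.device != targets.device:
    targets = targets.to(logits.device)
loss = F.cross_entropy(logits, targets)
```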