fusion-bench 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/__main__.py +2 -2
  3. fusion_bench/constants/runtime.py +4 -1
  4. fusion_bench/dataset/__init__.py +2 -0
  5. fusion_bench/dataset/clip_dataset.py +4 -72
  6. fusion_bench/dataset/image_dataset.py +44 -18
  7. fusion_bench/method/base_algorithm.py +4 -0
  8. fusion_bench/method/classification/image_classification_finetune.py +1 -0
  9. fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
  10. fusion_bench/method/dop/dop.py +0 -22
  11. fusion_bench/method/dop/dop_general.py +489 -0
  12. fusion_bench/method/dop/utils.py +24 -4
  13. fusion_bench/method/emr_merging/__init__.py +1 -0
  14. fusion_bench/method/emr_merging/emr_merging.py +53 -0
  15. fusion_bench/method/emr_merging/utils.py +162 -0
  16. fusion_bench/method/opcm/opcm.py +6 -2
  17. fusion_bench/method/opcm/opcm_general.py +356 -0
  18. fusion_bench/method/opcm/utils.py +1 -4
  19. fusion_bench/method/simple_average.py +52 -18
  20. fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
  21. fusion_bench/method/task_singular_vector/TSVM.py +7 -6
  22. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
  23. fusion_bench/mixins/lightning_fabric.py +110 -11
  24. fusion_bench/mixins/openclip_classification.py +155 -1
  25. fusion_bench/mixins/serialization.py +1 -1
  26. fusion_bench/modelpool/base_pool.py +37 -0
  27. fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
  28. fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
  29. fusion_bench/models/hf_clip.py +20 -0
  30. fusion_bench/models/modulator/__init__.py +1 -0
  31. fusion_bench/models/modulator/base.py +123 -0
  32. fusion_bench/models/open_clip/modeling.py +61 -5
  33. fusion_bench/models/open_clip/utils.py +13 -2
  34. fusion_bench/models/parameter_dict.py +119 -29
  35. fusion_bench/models/utils.py +190 -2
  36. fusion_bench/models/wrappers/switch.py +90 -0
  37. fusion_bench/programs/base_program.py +6 -0
  38. fusion_bench/programs/fabric_fusion_program.py +4 -0
  39. fusion_bench/py.typed +1 -0
  40. fusion_bench/scripts/cli.py +25 -23
  41. fusion_bench/scripts/imgui.py +2 -2
  42. fusion_bench/scripts/webui.py +2 -2
  43. fusion_bench/taskpool/image_classification.py +270 -0
  44. fusion_bench/utils/__init__.py +20 -1
  45. fusion_bench/utils/data.py +1 -1
  46. fusion_bench/utils/dict.py +19 -0
  47. fusion_bench/utils/dtype.py +19 -0
  48. fusion_bench/utils/hydra_utils.py +75 -0
  49. fusion_bench/utils/misc.py +1 -0
  50. fusion_bench/utils/packages.py +4 -0
  51. fusion_bench/utils/parameters.py +33 -0
  52. fusion_bench/utils/rich_utils.py +42 -19
  53. fusion_bench/utils/state_dict_arithmetic.py +183 -1
  54. fusion_bench/utils/tensorboard.py +21 -3
  55. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
  56. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +70 -53
  57. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
  58. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
  59. fusion_bench_config/README.md +9 -0
  60. fusion_bench_config/fabric/auto.yaml +1 -0
  61. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
  62. fusion_bench_config/hydra/default.yaml +3 -1
  63. fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
  64. fusion_bench_config/method/dop/dop_general.yaml +33 -0
  65. fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
  66. fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
  67. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
  68. fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
  69. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
  70. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0

--- a/fusion_bench/method/simple_average.py
+++ b/fusion_bench/method/simple_average.py
@@ -3,11 +3,16 @@ from copy import deepcopy
 from typing import Dict, List, Mapping, Optional, Union
 
 import torch
-from torch import nn
+from torch import Tensor, nn
 
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.models.utils import (
+    get_target_state_dict,
+    load_state_dict_into_target_modules,
+    validate_target_modules_equal,
+)
 from fusion_bench.utils import LazyStateDict
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
@@ -21,21 +26,22 @@ log = logging.getLogger(__name__)
 
 
 def simple_average(
-    modules: List[Union[nn.Module, StateDictType]],
-    base_module: Optional[nn.Module] = None,
+    modules: List[Union[nn.Module, StateDictType, Tensor]],
+    base_module: Optional[Union[nn.Module, StateDictType, Tensor]] = None,
 ):
     R"""
     Averages the parameters of a list of PyTorch modules or state dictionaries.
 
     This function takes a list of PyTorch modules or state dictionaries and returns a new module with the averaged parameters, or a new state dictionary with the averaged parameters.
 
+    If `_fusion_bench_target_modules` attribute is set on the modules, only the parameters of the specified target submodules will be averaged.
+
     Args:
-        modules (List[Union[nn.Module, StateDictType]]): A list of PyTorch modules or state dictionaries.
-        base_module (Optional[nn.Module]): A base module to use for the new module. If provided, the averaged parameters will be loaded into this module. If not provided, a new module will be created by copying the first module in the list.
+        modules (List[Union[nn.Module, StateDictType, Tensor]]): A list of PyTorch modules or state dictionaries.
+        base_module (Optional[Union[nn.Module, StateDictType, Tensor]]): A base module to use for the new module. If provided, the averaged parameters will be loaded into this module. If not provided, a new module will be created by copying the first module in the list.
 
     Returns:
-        module_or_state_dict (Union[nn.Module, StateDictType]): A new PyTorch module with the averaged parameters, or a new state dictionary with the averaged parameters.
-
+        module_or_state_dict (Union[nn.Module, StateDictType, Tensor]): A new PyTorch module with the averaged parameters, or a new state dictionary with the averaged parameters.
     Examples:
         >>> import torch.nn as nn
         >>> model1 = nn.Linear(10, 10)
@@ -47,23 +53,42 @@ def simple_average(
         >>> averaged_state_dict = simple_average([state_dict1, state_dict2])
     """
     assert len(modules) > 0, "modules must be a non-empty list"
+    validate_target_modules_equal(modules)
+
     if isinstance(modules[0], nn.Module):
         if base_module is None:
             new_module = deepcopy(modules[0])
         else:
             new_module = base_module
-        state_dict = state_dict_avg([module.state_dict() for module in modules])
-        new_module.load_state_dict(state_dict)
+        state_dict = state_dict_avg(
+            [get_target_state_dict(module) for module in modules]
+        )
+        load_state_dict_into_target_modules(new_module, state_dict)
         return new_module
     elif isinstance(modules[0], Mapping):
-        return state_dict_avg(modules)
+        # if the modules are state dicts
+        # compute the average state dict
+        avg_state_dict = state_dict_avg(modules)
+        # load into base_module if provided
+        if base_module is not None:
+            for k in avg_state_dict:
+                base_module[k] = avg_state_dict[k]
+            return base_module
+        else:
+            return avg_state_dict
+    elif isinstance(modules[0], Tensor):
+        mean_tensor = torch.stack(modules, dim=0).mean(dim=0)
+        if base_module is not None:
+            base_module.data = mean_tensor
+            return base_module
+        else:
+            return mean_tensor
+    else:
+        raise ValueError(f"Unsupported type: {type(modules[0])}")
 
 
 @auto_register_config
-class SimpleAverageAlgorithm(
-    SimpleProfilerMixin,
-    BaseAlgorithm,
-):
+class SimpleAverageAlgorithm(SimpleProfilerMixin, BaseAlgorithm):
     def __init__(self, show_pbar: bool = False, inplace: bool = True, **kwargs):
         """
         Args:
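
Usage note (not part of the diff): a minimal sketch of the extended `simple_average`, written against the branches shown above. It assumes the new target-module helpers fall back to the full model when `_fusion_bench_target_modules` is not set, as the docstring implies, and that all input tensors share a shape.

import torch
import torch.nn as nn

from fusion_bench.method.simple_average import simple_average

# nn.Module inputs: a copy of the first module is returned with averaged weights.
model1, model2 = nn.Linear(10, 10), nn.Linear(10, 10)
merged_module = simple_average([model1, model2])

# State-dict inputs: an averaged state dict is returned (or written into base_module).
merged_sd = simple_average([model1.state_dict(), model2.state_dict()])

# Tensor inputs are new in 0.2.32: tensors are stacked and averaged element-wise.
t1, t2 = torch.randn(3, 4), torch.randn(3, 4)
assert torch.allclose(simple_average([t1, t2]), (t1 + t2) / 2)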
@@ -87,13 +112,20 @@ class SimpleAverageAlgorithm(
         Returns:
             The fused model obtained by simple averaging.
         """
-        if isinstance(modelpool, dict):
+        if not isinstance(modelpool, BaseModelPool):
             modelpool = BaseModelPool(modelpool)
 
         log.info(
             f"Fusing models using simple average on {len(modelpool.model_names)} models. "
             f"models: {modelpool.model_names}"
         )
+        if modelpool.has_instance_models and self.inplace:
+            log.warning(
+                "The model pool contains instance models, and inplace is set to True. "
+                "Therefore, the weights of the first model will be overwritten. "
+                "If this is desired behavior, this warning can be ignored."
+            )
+
         sd: Optional[StateDictType] = None
         forward_model = None
         merged_model_names = []
@@ -106,12 +138,12 @@
             with self.profile("merge weights"):
                 if sd is None:
                     # Initialize the state dictionary with the first model's state dictionary
-                    sd = model.state_dict()
+                    sd = get_target_state_dict(model)
                     forward_model = model if self.inplace else deepcopy(model)
                 else:
                     # Add the current model's state dictionary to the accumulated state dictionary
                     sd = state_dict_add(
-                        sd, model.state_dict(), show_pbar=self.show_pbar
+                        sd, get_target_state_dict(model), show_pbar=self.show_pbar
                     )
         with self.profile("merge weights"):
             # Divide the accumulated state dictionary by the number of models to get the average
@@ -124,11 +156,13 @@
             forward_model = deepcopy(forward_model.meta_module).to_empty(
                 device=forward_model._device
             )
-        result = forward_model.load_state_dict(sd, strict=False)
+
+        result = load_state_dict_into_target_modules(forward_model, sd, strict=False)
         if result.unexpected_keys:
             raise ValueError(f"Unexpected keys in state dict: {result.unexpected_keys}")
         if result.missing_keys:
             log.warning(f"Missing keys in state dict: {result.missing_keys}")
+
         # print profile report and log the merged models
         self.print_profile_summary()
         log.info(f"merged {len(merged_model_names)} models:")
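
And a hedged sketch of the class-based interface: `run` wraps anything that is not already a `BaseModelPool`, so a plain dict of instantiated models can be passed directly (how `load_model` resolves such instance entries is assumed here, not shown in this hunk).

import torch.nn as nn

from fusion_bench.method.simple_average import SimpleAverageAlgorithm

models = {"task_a": nn.Linear(8, 8), "task_b": nn.Linear(8, 8)}

# inplace=False sidesteps the warning above: the first model is deep-copied
# before the averaged weights are loaded into it.
algorithm = SimpleAverageAlgorithm(show_pbar=False, inplace=False)
merged = algorithm.run(models)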

--- a/fusion_bench/method/task_arithmetic/task_arithmetic.py
+++ b/fusion_bench/method/task_arithmetic/task_arithmetic.py
@@ -50,7 +50,7 @@ def task_arithmetic_merge(
         finetuned_models (List[nn.Module]): A list of fine-tuned models from which task vectors will be calculated.
         scaling_factor (float): A factor by which the task vectors will be scaled before merging.
         inplace (bool, optional): If True, the pre-trained model will be modified in place.
-        If False, a copy of the pre-trained model will be modified. Defaults to True.
+            If False, a copy of the pre-trained model will be modified. Defaults to True.
 
     Returns:
         nn.Module: The pre-trained model with the merged task vectors.

--- a/fusion_bench/method/task_singular_vector/TSVM.py
+++ b/fusion_bench/method/task_singular_vector/TSVM.py
@@ -249,12 +249,13 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
         # - SVD finds the principal components (most important directions)
         # - Task vectors are reconstructed using only the most significant components
         # - The reconstructed vectors are merged (summed) to create a unified task vector
-        new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
-            task_vectors,
-            exclude_keys=self.exclude_keys,  # Skip certain parameters from SVD
-            accelerator=accelerator,  # Use GPU if available
-            return_single_task_models=self.return_single_task_models,
-        )
+        with torch.no_grad():
+            new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
+                task_vectors,
+                exclude_keys=self.exclude_keys,  # Skip certain parameters from SVD
+                accelerator=accelerator,  # Use GPU if available
+                return_single_task_models=self.return_single_task_models,
+            )
 
         # Handle the case where individual transformed task vectors are also returned
         if self.return_single_task_models:

--- a/fusion_bench/method/task_singular_vector/utils/TSVM_utils.py
+++ b/fusion_bench/method/task_singular_vector/utils/TSVM_utils.py
@@ -311,7 +311,6 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(
 
 ###############
 #### TSV Merge Orthogonalization
-@torch.no_grad()
 def compute_and_sum_svd_mem_reduction(
     task_vectors: List[StateDictType],
     exclude_keys: Optional[List[str]] = None,
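
The two hunks above move gradient disabling from the utility (the dropped `@torch.no_grad()` decorator) to the call site in `TaskSingularVectorMerging`. The underlying PyTorch pattern, illustrated here on a toy function rather than the real TSVM code, keeps the routine reusable where gradients are wanted while the merging path still avoids building an autograd graph:

import torch

def scale_task_vectors(task_vectors, coef=0.5):
    # purely numerical work; whether autograd tracks it is now the caller's choice
    return [coef * tv for tv in task_vectors]

task_vectors = [torch.randn(4, 4, requires_grad=True) for _ in range(3)]

with torch.no_grad():
    merged = scale_task_vectors(task_vectors)

assert all(not t.requires_grad for t in merged)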

--- a/fusion_bench/mixins/lightning_fabric.py
+++ b/fusion_bench/mixins/lightning_fabric.py
@@ -1,6 +1,7 @@
 import functools
 import logging
 import os
+import sys
 from typing import TYPE_CHECKING, Any, List, Mapping, Optional, TypeVar
 
 import lightning as L
@@ -10,18 +11,34 @@ from lightning.fabric.loggers import TensorBoardLogger
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
 
+from fusion_bench.constants import RuntimeConstants
 from fusion_bench.utils import import_object
+from fusion_bench.utils.hydra_utils import get_hydra_output_dir
 from fusion_bench.utils.instantiate_utils import instantiate
 
 if TYPE_CHECKING:
     import lightning.fabric.loggers.tensorboard
     from lightning.fabric.strategies import FSDPStrategy
+    from lightning.pytorch.loggers import MLFlowLogger
+    from mlflow.tracking.client import MlflowClient
 
 log = logging.getLogger(__name__)
 
 TensorOrModule = TypeVar("TensorOrModule", torch.Tensor, torch.nn.Module, Any)
 
 
+def _fabric_has_logger(fabric: L.Fabric) -> bool:
+    """
+    Check if the fabric has a logger.
+
+    Args:
+        fabric (L.Fabric): The Lightning Fabric instance.
+    Returns:
+        bool: True if the fabric has a logger, False otherwise.
+    """
+    return fabric._loggers is not None and len(fabric._loggers) > 0
+
+
 def get_policy(*args: str) -> set:
     """
     Get the policy from the provided list of policy names.
@@ -42,6 +59,21 @@ def get_size_based_auto_wrap_policy(*args, **kwargs):
     return policy
 
 
+def _is_mlflow_logger(fabric: L.Fabric) -> bool:
+    """
+    Check if the fabric's logger is an instance of MLFlowLogger.
+
+    Args:
+        fabric (L.Fabric): The Lightning Fabric instance.
+
+    Returns:
+        bool: True if the logger is an instance of MLFlowLogger, False otherwise.
+    """
+    if not _fabric_has_logger(fabric):
+        return False
+    return fabric.logger.__class__.__name__ == "MLFlowLogger"
+
+
 class LightningFabricMixin:
     """
     A mixin class for integrating Lightning Fabric into a project.
@@ -78,8 +110,8 @@
         """
         if self._fabric_instance is None:
            if config.get("fabric", None) is None:
-                log.warning("No fabric configuration found. use default settings.")
-                self._fabric_instance = L.Fabric()
+                log.warning("No fabric configuration found. use default settings. By default, use 1 device.")
+                self._fabric_instance = L.Fabric(devices=1)
            else:
                self._fabric_instance = instantiate(config.fabric)
            if not _is_using_cli():  # if not using cli, launch the fabric
@@ -122,7 +154,10 @@
         Retrieves the log directory from the fabric's logger.
         """
         if self.fabric is not None and len(self.fabric._loggers) > 0:
-            log_dir = self.fabric.logger.log_dir
+            if hasattr(self.fabric.logger, "log_dir"):
+                log_dir = self.fabric.logger.log_dir
+            else:
+                log_dir = None
 
             # Special handling for SwanLabLogger to get the correct log directory
             if (
@@ -131,6 +166,20 @@
             ):
                 log_dir = self.fabric.logger.save_dir or self.fabric.logger._logdir
 
+            if (
+                log_dir is None
+                and self.fabric.logger.__class__.__name__ == "MLFlowLogger"
+            ):
+                log_dir = self.fabric.logger.save_dir
+                if log_dir is None:
+                    try:
+                        log_dir = self._program.config.path.log_dir
+                    except Exception:
+                        log.error(
+                            "Failed to get log_dir from program config for MLFlowLogger."
+                        )
+                        log_dir = "outputs"
+
         assert log_dir is not None, "log_dir should not be None"
         if self.fabric.is_global_zero and not os.path.exists(log_dir):
             os.makedirs(log_dir, exist_ok=True)
@@ -206,14 +255,7 @@
         Returns:
             bool: True if fast_dev_run is enabled, False otherwise.
         """
-        if hasattr(self, "config") and self.config.get("fast_dev_run", False):
-            return True
-        elif hasattr(self, "_program") and self._program.config.get(
-            "fast_dev_run", False
-        ):
-            return True
-        else:
-            return False
+        return RuntimeConstants().debug
 
     def log(self, name: str, value: Any, step: Optional[int] = None):
         """
@@ -252,3 +294,60 @@
         """
         for i, param_group in enumerate(optimizer.param_groups):
             self.fabric.log(name_template.format(i), param_group["lr"], step=step)
+
+    def log_artifact(self, local_path: str, artifact_path: str | None = None):
+        """
+        Logs a file as an artifact to the fabric's logger.
+
+        Args:
+            local_dir: The path to the directory to log as an artifact.
+            artifact_path: The directory within the logger's artifact storage to save the file.
+        """
+        if _is_mlflow_logger(self.fabric):
+            logger: "MLFlowLogger" = self.fabric.logger
+            experiment: "MlflowClient" = logger.experiment
+            experiment.log_artifact(
+                logger.run_id,
+                local_path=local_path,
+                artifact_path=(artifact_path),
+            )
+
+    def log_artifacts(self, local_dir: str, artifact_path: str | None = None):
+        """
+        Logs a directory as artifacts to the fabric's logger.
+
+        Args:
+            local_dir: The path to the directory to log as artifacts.
+            artifact_path: The directory within the logger's artifact storage to save the files.
+        """
+        if _is_mlflow_logger(self.fabric):
+            logger: "MLFlowLogger" = self.fabric.logger
+            experiment: "MlflowClient" = logger.experiment
+            experiment.log_artifacts(
+                logger.run_id,
+                local_dir=local_dir,
+                artifact_path=artifact_path,
+            )
+
+    def finalize(self):
+        """
+        Destructor to ensure proper cleanup of the Lightning Fabric instance.
+        """
+        if self._fabric_instance is None:
+            return
+
+        if _fabric_has_logger(self.fabric) and _is_mlflow_logger(self.fabric):
+            if sys.exc_info()[0] is None:
+                status = "success"
+            else:
+                status = "failed"
+            self.fabric.logger.finalize(status)
+
+        del self._fabric_instance
+        self._fabric_instance = None
+
+    def __del__(self):
+        """
+        Destructor to ensure proper cleanup of the Lightning Fabric instance.
+        """
+        self.finalize()
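
A hedged sketch of the new artifact hooks from an algorithm's point of view. It assumes the mixin's `log_dir` property resolves to a writable directory and that the fabric is configured with Lightning's `MLFlowLogger` (a matching config file, fusion_bench_config/fabric/loggers/mlflow_logger.yaml, is added in this release); with any other logger, both calls are silently skipped by the `_is_mlflow_logger` guard.

import json
import os

from fusion_bench.method.base_algorithm import BaseAlgorithm
from fusion_bench.mixins import LightningFabricMixin


class ReportLoggingAlgorithm(LightningFabricMixin, BaseAlgorithm):
    """Hypothetical algorithm that writes a metrics file and uploads it to MLflow."""

    def run(self, modelpool):
        report_path = os.path.join(self.log_dir, "report.json")
        with open(report_path, "w") as f:
            json.dump({"num_models": len(modelpool.model_names)}, f)

        self.log_artifact(report_path)                         # single file
        self.log_artifacts(self.log_dir, artifact_path="run")  # whole directory
        return modelpool.load_model(modelpool.model_names[0])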

--- a/fusion_bench/mixins/openclip_classification.py
+++ b/fusion_bench/mixins/openclip_classification.py
@@ -1,11 +1,165 @@
+import functools
 import logging
+from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Literal, Optional
 
+import torch
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import LightningFabricMixin
-from fusion_bench.models.open_clip import ImageClassifier, ImageEncoder
+from fusion_bench.modelpool import OpenCLIPVisionModelPool
+from fusion_bench.models.open_clip import (
+    ClassificationHead,
+    ImageClassifier,
+    ImageEncoder,
+)
+from fusion_bench.utils.data import InfiniteDataLoader
 
 log = logging.getLogger(__name__)
 
 
 class OpenCLIPClassificationMixin(LightningFabricMixin):
+
     _train_processor = None
     _test_processor = None
+    dataloader_kwargs: DictConfig
+    modelpool: OpenCLIPVisionModelPool
+    zero_shot_heads: Dict[str, ClassificationHead] = {}
+
+    def _init_processor(self, encoder: Optional["ImageEncoder"] = None):
+        """
+        Initialize the CLIP processors for training and testing.
+        """
+        if encoder is None:
+            encoder: "ImageEncoder" = self.modelpool.load_pretrained_or_first_model()
+        self._train_processor = encoder.train_preprocess
+        self._test_processor = encoder.val_preprocess
+        return self._train_processor, self._test_processor
+
+    def get_clip_processor(self, stage: Literal["train", "test"]):
+        """
+        Get the CLIP processor, loading it from the model pool if necessary.
+
+        Returns:
+            CLIPProcessor: The CLIP processor for image and text preprocessing.
+
+        Raises:
+            AssertionError: If the model pool is not set.
+        """
+        if stage == "train":
+            if self._train_processor is None:
+                self._init_processor()
+            return self._train_processor
+        elif stage == "test":
+            if self._test_processor is None:
+                self._init_processor()
+            return self._test_processor
+        else:
+            raise ValueError(f"Invalid stage: {stage}")
+
+    def setup_zero_shot_classification_head(
+        self,
+        task_names: Optional[List[str]] = None,
+        freeze: bool = True,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        # check task names consistency across processes
+        _task_names = self.fabric.broadcast(task_names, src=0)
+        if not self.fabric.is_global_zero and task_names != _task_names:
+            raise ValueError("The `task_names` must be the same across all processes.")
+
+        for task in tqdm(
+            self.modelpool.model_names if task_names is None else task_names,
+            "Setting up zero-shot classification head",
+            disable=not self.fabric.is_global_zero,
+        ):
+            head = self.modelpool.load_classification_head(task)
+            if freeze:
+                head.requires_grad_(False)
+            if dtype is not None:
+                head = head.to(dtype=dtype)
+            self.zero_shot_heads[task] = self.to_device(head)
+
+    def set_clip_processor(self, stage: Literal["train", "test"], processor: Callable):
+        """
+        Set the CLIP processor for a specific stage.
+
+        Args:
+            stage (Literal["train", "test"]): The stage for which to set the processor.
+            processor (Callable): The CLIP processor to set.
+        """
+        if stage == "train":
+            self._train_processor = processor
+        elif stage == "test":
+            self._test_processor = processor
+        else:
+            raise ValueError(f"Invalid stage: {stage}")
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(
+        self,
+        task: str,
+        batch_size: Optional[int] = None,
+        num_workers: Optional[int] = None,
+        **loader_kwargs,
+    ) -> Iterator:
+        """
+        Get an iterator for a shuffled test DataLoader.
+
+        This method creates a DataLoader for the test dataset of the specified task,
+        with shuffling enabled. It allows for optional customization of batch size,
+        number of workers, and other DataLoader keyword arguments.
+
+        Args:
+            task (str): The task identifier for which the test dataset is to be loaded.
+            batch_size (Optional[int]): The batch size to use for the DataLoader. If None, the default batch size is used.
+            num_workers (Optional[int]): The number of worker processes to use for data loading. If None, the default number of workers is used.
+            **loader_kwargs: Additional keyword arguments to pass to the DataLoader.
+
+        Returns:
+            Iterator: An iterator over the shuffled test DataLoader.
+        """
+        # get dataloader kwargs
+        dataloader_kwargs = self.dataloader_kwargs.copy()
+        dataloader_kwargs["shuffle"] = True
+        if batch_size is not None:
+            dataloader_kwargs["batch_size"] = batch_size
+        if num_workers is not None:
+            dataloader_kwargs["num_workers"] = num_workers
+        dataloader_kwargs.update(loader_kwargs)
+
+        # get the test dataset
+        clip_dataset = CLIPDataset(
+            self.modelpool.load_test_dataset(task),
+            processor=self.get_clip_processor(stage="test"),
+        )
+        # create the dataloader
+        loader = DataLoader(clip_dataset, **dataloader_kwargs)
+        loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def compute_logits(
+        self,
+        module: ImageClassifier,
+        images,
+        task: str,
+    ):
+        """
+        Compute the logits for a batch of images using the provided module and task.
+
+        Args:
+            module (ImageClassifier): The image classification module to use for computing logits.
+            images (torch.Tensor): The batch of images for which to compute logits.
+            task (str): The task identifier to specify which classification head to use.
+
+        Returns:
+            torch.Tensor: The computed logits for the input images.
+        """
+        if len(self.zero_shot_heads) == 0:
+            self.setup_zero_shot_classification_head()
+        task_head = self.zero_shot_heads[task]
+        features = module(images)
+        logits = task_head(features)
+        return logits
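
A hedged sketch of how an algorithm built on this mixin might use the new helpers. It assumes `self.modelpool` and `self.dataloader_kwargs` are populated by the surrounding program, and that the wrapped `CLIPDataset` yields `(images, labels)` batches.

import torch

from fusion_bench.mixins.openclip_classification import OpenCLIPClassificationMixin


class QuickEvalSketch(OpenCLIPClassificationMixin):
    @torch.no_grad()
    def quick_accuracy(self, encoder, task: str, num_batches: int = 10) -> float:
        """Estimate accuracy on a few shuffled test batches of `task`."""
        loader_iter = self.get_shuffled_test_loader_iter(task, batch_size=32)
        correct = total = 0
        for _ in range(num_batches):
            images, labels = next(loader_iter)
            logits = self.compute_logits(encoder, images, task)
            correct += (logits.argmax(dim=-1) == labels).sum().item()
            total += labels.numel()
        return correct / total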

--- a/fusion_bench/mixins/serialization.py
+++ b/fusion_bench/mixins/serialization.py
@@ -68,7 +68,7 @@ def auto_register_config(cls):
 
    Behavior:
    - **Parameter Registration**: All non-variadic parameters (excluding ``*args``, ``**kwargs``)
-    from the __init__ method are automatically added to _config_mapping
+      from the __init__ method are automatically added to _config_mapping
    - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
    - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
    - **Default Values**: Applied when parameters are not provided via arguments
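
A hedged toy example of the decorator's effect as described above; reading `self.scaling_factor` and `self.show_pbar` in `run` mirrors how `SimpleAverageAlgorithm` earlier in this diff uses `self.show_pbar` and `self.inplace` without assigning them in `__init__`.

from fusion_bench.method.base_algorithm import BaseAlgorithm
from fusion_bench.mixins import auto_register_config


@auto_register_config
class ScaledAverageSketch(BaseAlgorithm):
    def __init__(self, scaling_factor: float = 0.3, show_pbar: bool = False, **kwargs):
        # scaling_factor and show_pbar are registered in _config_mapping (and, per
        # the usage pattern above, exposed as attributes) without manual bookkeeping.
        super().__init__(**kwargs)

    def run(self, modelpool):
        print(self.scaling_factor, self.show_pbar)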

--- a/fusion_bench/modelpool/base_pool.py
+++ b/fusion_bench/modelpool/base_pool.py
@@ -7,10 +7,12 @@ from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
 from torch import nn
 from torch.utils.data import Dataset
 
+from fusion_bench import StateDictType, TorchModelType
 from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
 from fusion_bench.utils import (
     ValidationError,
     instantiate,
+    state_dict_sub,
     timeit_context,
     validate_model_name,
 )
@@ -56,6 +58,10 @@ class BaseModelPool(
         **kwargs,
     ):
         if isinstance(models, List):
+            log.debug(
+                "Initializing BaseModelPool with a list of models. "
+                "Converting to a dictionary with integer string keys."
+            )
             models = {str(model_idx): model for model_idx, model in enumerate(models)}
 
         if isinstance(models, dict):
@@ -80,6 +86,22 @@ class BaseModelPool(
         self._test_datasets = test_datasets
         super().__init__(**kwargs)
 
+    @property
+    def has_instance_models(self) -> bool:
+        """
+        Check if the model pool contains any pre-instantiated models.
+
+        Attention:
+            Some algorithms may modify the models in-place if they are pre-instantiated.
+
+        Returns:
+            bool: True if there are pre-instantiated models, False otherwise.
+        """
+        for model_cfg in self._models.values():
+            if isinstance(model_cfg, nn.Module):
+                return True
+        return False
+
     @property
     def has_pretrained(self) -> bool:
         """
@@ -328,6 +350,21 @@
         for model_name in self.model_names:
             yield model_name, self.load_model(model_name)
 
+    def load_pretrained_model_and_task_vectors(
+        self,
+    ) -> Tuple[TorchModelType, List[StateDictType]]:
+        pretrained_model = self.load_pretrained_model()
+
+        task_vectors = []
+        for model_name in self.model_names:
+            finetuned_model = self.load_model(model_name)
+            task_vector = state_dict_sub(
+                finetuned_model.state_dict(), pretrained_model.state_dict()
+            )
+            task_vectors.append(task_vector)
+
+        return pretrained_model, task_vectors
+
     @property
     def has_train_dataset(self) -> bool:
         """

--- a/fusion_bench/modelpool/convnext_for_image_classification.py
+++ b/fusion_bench/modelpool/convnext_for_image_classification.py
@@ -98,7 +98,7 @@ class ConvNextForImageClassificationPool(BaseModelPool):
     - Load ConvNeXt models either from a pretrained checkpoint or from config.
     - Optionally adapt the classifier head to match dataset classnames.
     - Override `forward` to return logits for consistent interfaces within
-        FusionBench.
+      FusionBench.
 
     See `fusion_bench.modelpool.resnet_for_image_classification` for a closely
     related ResNet-based pool with analogous behavior.
@@ -161,6 +161,9 @@ class ConvNextForImageClassificationPool(BaseModelPool):
             ).logits
         model.original_forward = original_forward
 
+        # Mark ConvNeXt layers for FusionBench fusion
+        model._fusion_bench_target_modules = ["convnext"]
+
         return model
 
     @override
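
The marker set above feeds the target-module-aware averaging added to `simple_average` earlier in this diff. A hedged sketch of the intended effect, loading the public facebook/convnext-base-224 checkpoint twice purely for illustration (in practice the inputs would be differently fine-tuned models); per the docstring, only the marked `convnext` backbone is averaged, so the returned copy presumably keeps the first model's classifier head.

from transformers import ConvNextForImageClassification

from fusion_bench.method.simple_average import simple_average

model_a = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")
model_b = ConvNextForImageClassification.from_pretrained("facebook/convnext-base-224")

for m in (model_a, model_b):
    m._fusion_bench_target_modules = ["convnext"]  # average the backbone only

merged = simple_average([model_a, model_b])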
@@ -180,7 +183,7 @@ class ConvNextForImageClassificationPool(BaseModelPool):
         - The ConvNeXt model via `model.save_pretrained`.
         - The paired image processor via `AutoImageProcessor.save_pretrained`.
         - If `algorithm_config` is provided and on rank-zero, a README model card
-        documenting the FusionBench configuration.
+          documenting the FusionBench configuration.
         """
         model.save_pretrained(path)
         self.load_processor().save_pretrained(path)
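
Finally, a hedged sketch tying the `BaseModelPool` additions from earlier in this diff together: `load_pretrained_model_and_task_vectors` combined with plain task arithmetic. Here `modelpool` stands for an already-constructed pool that defines a pretrained model alongside its fine-tuned models, and the scaling coefficient 0.3 is only a placeholder.

from fusion_bench.utils.state_dict_arithmetic import state_dict_add

pretrained_model, task_vectors = modelpool.load_pretrained_model_and_task_vectors()

# Sum the task vectors, then apply them to the pretrained weights at a chosen scale.
merged_tv = task_vectors[0]
for tv in task_vectors[1:]:
    merged_tv = state_dict_add(merged_tv, tv)

lam = 0.3
merged_sd = {
    name: (param + lam * merged_tv[name]) if name in merged_tv else param
    for name, param in pretrained_model.state_dict().items()
}
pretrained_model.load_state_dict(merged_sd)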