fusion-bench 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/__main__.py +2 -2
  3. fusion_bench/constants/runtime.py +4 -1
  4. fusion_bench/dataset/__init__.py +2 -0
  5. fusion_bench/dataset/clip_dataset.py +4 -72
  6. fusion_bench/dataset/image_dataset.py +44 -18
  7. fusion_bench/method/base_algorithm.py +4 -0
  8. fusion_bench/method/classification/image_classification_finetune.py +1 -0
  9. fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
  10. fusion_bench/method/dop/dop.py +0 -22
  11. fusion_bench/method/dop/dop_general.py +489 -0
  12. fusion_bench/method/dop/utils.py +24 -4
  13. fusion_bench/method/emr_merging/__init__.py +1 -0
  14. fusion_bench/method/emr_merging/emr_merging.py +53 -0
  15. fusion_bench/method/emr_merging/utils.py +162 -0
  16. fusion_bench/method/opcm/opcm.py +6 -2
  17. fusion_bench/method/opcm/opcm_general.py +356 -0
  18. fusion_bench/method/opcm/utils.py +1 -4
  19. fusion_bench/method/simple_average.py +52 -18
  20. fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
  21. fusion_bench/method/task_singular_vector/TSVM.py +7 -6
  22. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
  23. fusion_bench/mixins/lightning_fabric.py +110 -11
  24. fusion_bench/mixins/openclip_classification.py +155 -1
  25. fusion_bench/mixins/serialization.py +1 -1
  26. fusion_bench/modelpool/base_pool.py +37 -0
  27. fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
  28. fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
  29. fusion_bench/models/hf_clip.py +20 -0
  30. fusion_bench/models/modulator/__init__.py +1 -0
  31. fusion_bench/models/modulator/base.py +123 -0
  32. fusion_bench/models/open_clip/modeling.py +61 -5
  33. fusion_bench/models/open_clip/utils.py +13 -2
  34. fusion_bench/models/parameter_dict.py +119 -29
  35. fusion_bench/models/utils.py +190 -2
  36. fusion_bench/models/wrappers/switch.py +90 -0
  37. fusion_bench/programs/base_program.py +6 -0
  38. fusion_bench/programs/fabric_fusion_program.py +4 -0
  39. fusion_bench/py.typed +1 -0
  40. fusion_bench/scripts/cli.py +25 -23
  41. fusion_bench/scripts/imgui.py +2 -2
  42. fusion_bench/scripts/webui.py +2 -2
  43. fusion_bench/taskpool/image_classification.py +270 -0
  44. fusion_bench/utils/__init__.py +20 -1
  45. fusion_bench/utils/data.py +1 -1
  46. fusion_bench/utils/dict.py +19 -0
  47. fusion_bench/utils/dtype.py +19 -0
  48. fusion_bench/utils/hydra_utils.py +75 -0
  49. fusion_bench/utils/misc.py +1 -0
  50. fusion_bench/utils/packages.py +4 -0
  51. fusion_bench/utils/parameters.py +33 -0
  52. fusion_bench/utils/rich_utils.py +42 -19
  53. fusion_bench/utils/state_dict_arithmetic.py +183 -1
  54. fusion_bench/utils/tensorboard.py +21 -3
  55. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
  56. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +70 -53
  57. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
  58. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
  59. fusion_bench_config/README.md +9 -0
  60. fusion_bench_config/fabric/auto.yaml +1 -0
  61. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
  62. fusion_bench_config/hydra/default.yaml +3 -1
  63. fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
  64. fusion_bench_config/method/dop/dop_general.yaml +33 -0
  65. fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
  66. fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
  67. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
  68. fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
  69. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
  70. {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/models/utils.py CHANGED
@@ -1,9 +1,37 @@
-from typing import List
+from typing import Iterable, List, Optional
 
 import torch
 from torch import nn
+from torch.nn.modules.module import _IncompatibleKeys
 
-from fusion_bench.utils.type import StateDictType
+from fusion_bench.utils.dict import dict_merge
+from fusion_bench.utils.type import StateDictType, TorchModelType
+
+
+def is_leaf_module(module: nn.Module) -> bool:
+    return len(list(module.children())) == 0
+
+
+def named_leaf_modules(
+    module: nn.Module,
+    prefix: str = "",
+    ignore_empty: bool = True,
+) -> Iterable[tuple[str, nn.Module]]:
+    """
+    Recursively find the leaf modules in a module.
+
+    Args:
+        module (nn.Module): PyTorch module.
+        prefix (str): A prefix to add to the layer names.
+        ignore_empty (bool): If True, skip leaf modules that have no parameters.
+
+    Returns:
+        Iterable[tuple[str, nn.Module]]: An iterable of (name, module) tuples for each leaf module.
+    """
+    for name, submodule in module.named_modules(prefix=prefix):
+        if is_leaf_module(submodule):
+            if ignore_empty and len(list(submodule.parameters())) == 0:
+                continue
+            yield name, submodule
 
 
 def del_attr(obj, names: List[str]):
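A minimal usage sketch of the new leaf-module helpers (the toy model below is illustrative, not from the package):

    from torch import nn

    from fusion_bench.models.utils import named_leaf_modules

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
    # The ReLU is a leaf module but has no parameters, so ignore_empty=True skips it.
    print([name for name, _ in named_leaf_modules(model)])  # ['0', '2']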
@@ -104,3 +132,163 @@ def disable_dropout(model: torch.nn.Module):
     for module in model.modules():
         if isinstance(module, torch.nn.Dropout):
             module.p = 0
+
+
+def get_target_state_dict(
+    module: nn.Module,
+    target_modules: str | Iterable[str] | None = None,
+    prefix: str = "",
+    keep_vars: bool = False,
+) -> StateDictType:
+    """
+    Retrieve the state dictionary of the specified target submodules within a
+    PyTorch module, merged into a single state dictionary when multiple
+    submodules are given.
+
+    For example, if a model has submodules named "layer1", "layer2", and "layer3",
+    and you want the state dictionary of "layer1" and "layer3", call this function
+    with `target_modules=["layer1", "layer3"]`. The returned state dictionary
+    contains only the parameters and buffers of those submodules.
+
+    Args:
+        module (nn.Module): The PyTorch module containing the target submodules.
+        target_modules (str | Iterable[str]): A single target module name or an iterable of target module names.
+            If None, the module's `_fusion_bench_target_modules` attribute is consulted;
+            if that is also unset, the entire module's state dictionary is returned.
+        prefix (str): A prefix to prepend to all keys in the returned state dictionary.
+        keep_vars (bool): If True, keeps the variables in the state dictionary. Default is False.
+
+    Returns:
+        StateDictType: The state dictionary of the specified target submodules, merged if multiple are provided.
+    """
+    if target_modules is None:
+        if (
+            hasattr(module, "_fusion_bench_target_modules")
+            and module._fusion_bench_target_modules is not None
+        ):
+            return get_target_state_dict(
+                module,
+                target_modules=module._fusion_bench_target_modules,
+                prefix=prefix,
+                keep_vars=keep_vars,
+            )
+        else:
+            return module.state_dict(prefix=prefix, keep_vars=keep_vars)
+
+    if isinstance(target_modules, str):
+        target_modules = [target_modules]
+
+    state_dicts = []
+    for target_module in target_modules:
+        submodule_prefix = (
+            f"{prefix}{target_module}." if prefix else f"{target_module}."
+        )
+        submodule = module.get_submodule(target_module)
+        state_dict = submodule.state_dict(prefix=submodule_prefix, keep_vars=keep_vars)
+        state_dicts.append(state_dict)
+
+    merged_state_dict = dict_merge(state_dicts, disjoint=True)
+    return merged_state_dict
+
+
+def validate_target_modules_equal(modules: Iterable[nn.Module]) -> None:
+    """
+    Validates that the `_fusion_bench_target_modules` attribute is the same across all provided modules.
+
+    Args:
+        modules (Iterable[nn.Module]): An iterable of PyTorch modules to validate.
+
+    Raises:
+        ValueError: If the `_fusion_bench_target_modules` attribute differs among the modules.
+    """
+    model_iter = iter(modules)
+    first_module = next(model_iter)
+
+    if hasattr(first_module, "_fusion_bench_target_modules"):
+        target_modules = first_module._fusion_bench_target_modules
+    else:
+        # if the module does not have the attribute, treat it as None
+        target_modules = None
+
+    for module in model_iter:
+        if target_modules is None:
+            if (
+                hasattr(module, "_fusion_bench_target_modules")
+                and module._fusion_bench_target_modules != target_modules
+            ):
+                raise ValueError(
+                    "_fusion_bench_target_modules attribute differs among the provided modules."
+                )
+        else:
+            if (
+                not hasattr(module, "_fusion_bench_target_modules")
+                or module._fusion_bench_target_modules != target_modules
+            ):
+                raise ValueError(
+                    "_fusion_bench_target_modules attribute differs among the provided modules."
+                )
+
+
+def load_state_dict_into_target_modules(
+    module: TorchModelType,
+    state_dict: StateDictType,
+    target_modules: str | Iterable[str] | None = None,
+    strict: bool = True,
+    assign: bool = False,
+):
+    """
+    Load a state dictionary into specified target submodules within a given module of a PyTorch model.
+
+    This function allows you to load parameters and buffers from a state dictionary into specific submodules
+    of a PyTorch model. If the `target_modules` argument is provided, only the specified submodules are updated
+    with the corresponding entries from the state dictionary.
+
+    Args:
+        module (nn.Module): The PyTorch module containing the target submodules.
+        state_dict (StateDictType): The state dictionary containing parameters and buffers to load.
+        target_modules (str | Iterable[str]): A single target module name or an iterable of target module names.
+            If None, the module's `_fusion_bench_target_modules` attribute is consulted;
+            if that is also unset, the entire module's state dictionary is updated.
+        strict (bool): Whether to strictly enforce that the keys in `state_dict` match the keys returned by
+            the module's `state_dict()` function. Default is True.
+        assign (bool): Passed through to `load_state_dict`; if True, tensors are assigned rather than copied.
+    """
+    if target_modules is None:
+        if (
+            hasattr(module, "_fusion_bench_target_modules")
+            and module._fusion_bench_target_modules is not None
+        ):
+            return load_state_dict_into_target_modules(
+                module,
+                state_dict,
+                target_modules=module._fusion_bench_target_modules,
+                strict=strict,
+                assign=assign,
+            )
+        else:
+            return module.load_state_dict(state_dict, strict=strict, assign=assign)
+
+    if isinstance(target_modules, str):
+        target_modules = [target_modules]
+
+    assert (
+        len(target_modules) > 0
+    ), "target_modules should contain at least one module name."
+    results: list[_IncompatibleKeys] = []
+    for target_module in target_modules:
+        submodule_prefix = f"{target_module}."
+        submodule_prefix_len = len(submodule_prefix)
+        submodule = module.get_submodule(target_module)
+
+        # Extract only the portion of the state dictionary that belongs to this submodule
+        submodule_state_dict = {
+            key[submodule_prefix_len:]: value
+            for key, value in state_dict.items()
+            if key.startswith(submodule_prefix)
+        }
+
+        # Load the extracted state dictionary into the submodule
+        result = submodule.load_state_dict(
+            submodule_state_dict, strict=strict, assign=assign
+        )
+        results.append(result)
+
+    # Merge results from all submodules
+    merged_result = _IncompatibleKeys(
+        missing_keys=[key for res in results for key in res.missing_keys],
+        unexpected_keys=[key for res in results for key in res.unexpected_keys],
+    )
+    return merged_result
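A hedged round-trip sketch for the two new state-dict helpers; the two-submodule toy model and its names are assumptions for illustration:

    from torch import nn

    from fusion_bench.models.utils import (
        get_target_state_dict,
        load_state_dict_into_target_modules,
    )

    model = nn.ModuleDict({"encoder": nn.Linear(4, 4), "head": nn.Linear(4, 2)})

    # Restrict the state dict to one submodule; keys keep the "encoder." prefix.
    sd = get_target_state_dict(model, target_modules="encoder")
    print(sorted(sd))  # ['encoder.bias', 'encoder.weight']

    # Load it back into the same target; other submodules are untouched.
    result = load_state_dict_into_target_modules(model, sd, target_modules="encoder")
    print(result.missing_keys, result.unexpected_keys)  # [] []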
fusion_bench/models/wrappers/switch.py ADDED
@@ -0,0 +1,90 @@
+"""
+This module contains a wrapper for switching between different models.
+
+For example, it can be used to switch between different classification heads for a shared backbone.
+"""
+
+import logging
+from typing import Dict, Optional
+
+from torch import nn
+
+from fusion_bench.utils.misc import first, validate_and_suggest_corrections
+
+__all__ = ["SwitchModule", "set_active_option"]
+
+log = logging.getLogger(__name__)
+
+
+def _standardize_option_name(name: str) -> str:
+    """
+    Standardizes the option name by:
+
+    - Stripping whitespace and converting to lowercase.
+    - Replacing `-` with `_` if needed.
+    - Replacing `/` with `_` if needed.
+
+    Args:
+        name (str): The option name to standardize.
+    """
+    name = name.strip().lower()
+    name = name.replace("-", "_")
+    name = name.replace("/", "_")
+    return name
+
+
+class SwitchModule(nn.Module):
+    """
+    A wrapper module that contains multiple sub-modules (options) and allows switching between them.
+
+    This is useful for multi-head models or models where different parts are activated based on the task.
+    """
+
+    def __init__(self, modules: Dict[str, nn.Module]):
+        """
+        Args:
+            modules (Dict[str, nn.Module]): A dictionary of modules to switch between.
+        """
+        super().__init__()
+        standardized_modules = {
+            _standardize_option_name(name): module for name, module in modules.items()
+        }
+        self._option_modules = nn.ModuleDict(standardized_modules)
+        self._active_option = first(self._option_modules.keys())
+
+    def set_active_option(self, option_name: str):
+        standardized_name = _standardize_option_name(option_name)
+        validate_and_suggest_corrections(standardized_name, self._option_modules.keys())
+        self._active_option = standardized_name
+
+    def forward(self, *args, **kwargs):
+        active_module = self._option_modules[self._active_option]
+        return active_module(*args, **kwargs)
+
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            active_module = self._option_modules[self._active_option]
+            if hasattr(active_module, name):
+                return getattr(active_module, name)
+            raise
+
+
+def set_active_option(module: nn.Module, option_name: str) -> list[str]:
+    """
+    Utility function to set the active option for all SwitchModule instances within a given module.
+
+    Args:
+        module (nn.Module): The module to set the active option for.
+        option_name (str): The name of the option to activate.
+
+    Returns:
+        list[str]: A list of names of submodules that were activated.
+    """
+    activated_submodules = []
+    for name, submodule in module.named_modules():
+        if isinstance(submodule, SwitchModule):
+            submodule.set_active_option(option_name)
+            activated_submodules.append(name)
+    return activated_submodules
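A short sketch of how the wrapper composes with a shared backbone (head names and shapes invented for illustration):

    import torch
    from torch import nn

    from fusion_bench.models.wrappers.switch import SwitchModule, set_active_option

    heads = SwitchModule({"CIFAR-10": nn.Linear(8, 10), "MNIST": nn.Linear(8, 10)})
    model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), heads)

    # Option names are standardized, so "CIFAR-10" becomes "cifar_10".
    set_active_option(model, "cifar_10")  # returns the names of switched submodules
    logits = model(torch.randn(2, 16))    # routed through the active head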
fusion_bench/programs/base_program.py CHANGED
@@ -75,6 +75,12 @@ class BaseHydraProgram(BaseYAMLSerializable):
     - FusionBench CLI documentation for program execution details
     """
 
+    _program = None
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._program = self
+
     @abstractmethod
     def run(self):
         """
fusion_bench/programs/fabric_fusion_program.py CHANGED
@@ -267,6 +267,7 @@ class FabricModelFusionProgram(
         merged_model = self.method.run(self.modelpool)
         self.method.on_run_end()
 
+        report = None
         if merged_model is None:
             log.info(
                 "No merged model returned by the method. Skipping saving and evaluation."
@@ -293,5 +294,8 @@ class FabricModelFusionProgram(
                 )
                 os.makedirs(os.path.dirname(self.report_save_path), exist_ok=True)
                 json.dump(report, open(self.report_save_path, "w"))
+                self.log_artifact(local_path=self.report_save_path)
         else:
             log.info("No task pool specified. Skipping evaluation.")
+
+        return {"merged_model": merged_model, "report": report}
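Since the method now returns a dict instead of nothing, a caller driving the program from Python can pick up both outputs; a sketch, assuming these hunks sit inside FabricModelFusionProgram.run() and `program` is an instantiated program:

    outputs = program.run()
    merged_model = outputs["merged_model"]  # None if the method returned no model
    report = outputs["report"]              # None when no task pool was evaluated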
fusion_bench/py.typed ADDED
@@ -0,0 +1 @@
+
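The empty `py.typed` file is the PEP 561 marker: its presence tells type checkers such as mypy and pyright that `fusion_bench` ships usable inline type annotations.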
fusion_bench/scripts/cli.py CHANGED
@@ -3,36 +3,21 @@
 This is the CLI script that is executed when the user runs the `fusion_bench` command.
 The script is responsible for parsing the command-line arguments, loading the configuration file, and running the fusion algorithm.
 """
-
-import importlib
-import importlib.resources
 import logging
-import os
+from typing import TYPE_CHECKING
 
 import hydra
 from omegaconf import DictConfig, OmegaConf
 
-from fusion_bench.constants import PROJECT_ROOT_PATH
-from fusion_bench.programs import BaseHydraProgram
 from fusion_bench.utils import instantiate
+from fusion_bench.utils.hydra_utils import get_default_config_path
 
-log = logging.getLogger(__name__)
-
+if TYPE_CHECKING:
+    from fusion_bench.programs import BaseHydraProgram
 
-def _get_default_config_path():
-    for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
-        for config_dir in ["config", "fusion_bench_config"]:
-            config_path = os.path.join(config_path_root, config_dir)
-            if os.path.exists(config_path) and os.path.isdir(config_path):
-                return os.path.abspath(config_path)
-    return None
+log = logging.getLogger(__name__)
 
 
-@hydra.main(
-    config_path=_get_default_config_path(),
-    config_name="fabric_model_fusion",
-    version_base=None,
-)
 def main(cfg: DictConfig) -> None:
     """
     Main entry point for the FusionBench command-line interface.
@@ -68,7 +53,7 @@ def main(cfg: DictConfig) -> None:
     loading the corresponding configuration files to populate the cfg parameter.
     """
     OmegaConf.resolve(cfg)
-    program: BaseHydraProgram = instantiate(cfg)
+    program: "BaseHydraProgram" = instantiate(cfg)
 
     # Validate that instantiation succeeded and returned an object with 'run' method
     if not hasattr(program, "run") or not callable(getattr(program, "run")):
@@ -83,8 +68,25 @@ def main(cfg: DictConfig) -> None:
         err_msg += f"\n\nConfiguration content:\n{cfg}"
         raise TypeError(err_msg)
 
-    program.run()
+    try:
+        program_result = program.run()
+        return program_result
+    except BaseException as e:
+        # Log the exception before exiting
+        if hasattr(program, "finalize") and callable(getattr(program, "finalize")):
+            program.finalize()
+        log.error(e, exc_info=True)
+        raise e
+
+
+@hydra.main(
+    config_path=get_default_config_path(),
+    config_name="fabric_model_fusion",
+    version_base=None,
+)
+def _hydra_main(cfg: DictConfig) -> None:
+    main(cfg)
 
 
 if __name__ == "__main__":
-    main()
+    _hydra_main()
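Splitting the undecorated `main(cfg)` from the `@hydra.main`-wrapped `_hydra_main` means the CLI logic can now be driven with a pre-composed config, e.g. from Hydra's compose API; a sketch (the override is illustrative):

    from hydra import compose, initialize_config_dir

    from fusion_bench.scripts.cli import main
    from fusion_bench.utils.hydra_utils import get_default_config_path

    with initialize_config_dir(config_dir=get_default_config_path(), version_base=None):
        cfg = compose(config_name="fabric_model_fusion", overrides=["method=simple_average"])
    main(cfg)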
fusion_bench/scripts/imgui.py CHANGED
@@ -9,7 +9,7 @@ import hydra
 from hydra import compose, initialize_config_dir
 from omegaconf import DictConfig, ListConfig, OmegaConf
 
-from fusion_bench.scripts.cli import _get_default_config_path
+from fusion_bench.scripts.cli import get_default_config_path
 
 # Keeping the ConfigGroupNode and AppState classes as they are
 from fusion_bench.scripts.webui import AppState, ConfigGroupNode, priority_iterable
@@ -40,7 +40,7 @@ class App:
         if self.args.config_path:
             return Path(self.args.config_path)
         else:
-            return _get_default_config_path()
+            return get_default_config_path()
 
     def generate_ui(self):
         dpg.create_context()
fusion_bench/scripts/webui.py CHANGED
@@ -16,7 +16,7 @@ from hydra import compose, initialize_config_dir
 from hydra.core.hydra_config import HydraConfig
 from omegaconf import DictConfig, ListConfig, OmegaConf
 
-from fusion_bench.scripts.cli import _get_default_config_path
+from fusion_bench.scripts.cli import get_default_config_path
 
 
 def escape_overrides(value: str) -> str:
@@ -385,7 +385,7 @@ class App:
         if self.args.config_path:
             return Path(self.args.config_path)
         else:
-            return _get_default_config_path()
+            return get_default_config_path()
 
     def __getattr__(self, name):
         """