fusion-bench 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. fusion_bench/constants/runtime.py +4 -1
  2. fusion_bench/method/__init__.py +9 -1
  3. fusion_bench/method/base_algorithm.py +29 -19
  4. fusion_bench/method/classification/image_classification_finetune.py +1 -0
  5. fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
  6. fusion_bench/method/task_singular_vector/TSVM.py +7 -6
  7. fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
  8. fusion_bench/metrics/model_kinship/__init__.py +2 -0
  9. fusion_bench/metrics/model_kinship/calculate.py +77 -0
  10. fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
  11. fusion_bench/metrics/model_kinship/utility.py +184 -0
  12. fusion_bench/mixins/lightning_fabric.py +2 -8
  13. fusion_bench/mixins/openclip_classification.py +155 -1
  14. fusion_bench/modelpool/base_pool.py +1 -0
  15. fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
  16. fusion_bench/models/masks/mask_model.py +8 -2
  17. fusion_bench/models/open_clip/modeling.py +68 -5
  18. fusion_bench/models/open_clip/utils.py +13 -2
  19. fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
  20. fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
  21. fusion_bench/py.typed +1 -0
  22. fusion_bench/scripts/cli.py +21 -16
  23. fusion_bench/scripts/imgui.py +2 -2
  24. fusion_bench/scripts/webui.py +2 -2
  25. fusion_bench/utils/__init__.py +2 -0
  26. fusion_bench/utils/devices.py +3 -1
  27. fusion_bench/utils/hydra_utils.py +75 -0
  28. fusion_bench/utils/instantiate_utils.py +29 -18
  29. fusion_bench/utils/misc.py +16 -0
  30. fusion_bench/utils/parameters.py +33 -0
  31. fusion_bench/utils/rich_utils.py +165 -25
  32. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/METADATA +7 -7
  33. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/RECORD +41 -34
  34. fusion_bench_config/README.md +9 -0
  35. fusion_bench_config/fabric/auto.yaml +1 -0
  36. fusion_bench_config/hydra/default.yaml +3 -1
  37. fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
  38. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/WHEEL +0 -0
  39. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/entry_points.txt +0 -0
  40. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/licenses/LICENSE +0 -0
  41. {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/top_level.txt +0 -0
@@ -77,7 +77,16 @@ def torch_load_old(save_path: str, device=None):
77
77
  return classifier
78
78
 
79
79
 
80
- def torch_save(model, save_path, save_state_dict=True):
80
+ def torch_save(model: torch.nn.Module, save_path: str, save_state_dict: bool = True):
81
+ """
82
+ Save a model to disk.
83
+
84
+ Args:
85
+ model: The model to save.
86
+ save_path (str): The path to save the model to.
87
+ save_state_dict (bool): Whether to save the state dict of the model (weights only).
88
+ If False, the entire model object is saved. Default is True.
89
+ """
81
90
  # TODO: hacky way to save state dict
82
91
  if save_state_dict and isinstance(model, torch.nn.Module):
83
92
  model = model.state_dict()
@@ -86,7 +95,9 @@ def torch_save(model, save_path, save_state_dict=True):
86
95
  torch.save(model, save_path)
87
96
 
88
97
 
89
- def torch_load(save_path, device=None):
98
+ def torch_load(
99
+ save_path: str, device: Optional[torch.device] = None
100
+ ) -> torch.nn.Module:
90
101
  model = torch.load(save_path, map_location="cpu")
91
102
  if device is not None:
92
103
  model = model.to(device)
@@ -173,6 +173,24 @@ class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
173
173
 
174
174
  @property
175
175
  def forward_model(self):
176
+ """
177
+ Get a functional model with merged parameters.
178
+
179
+ Returns a partial function that applies the pretrained model with the current
180
+ merged state dictionary. This allows for efficient forward passes without
181
+ modifying the original model's parameters.
182
+
183
+ Returns:
184
+ Callable: A partial function that can be called with (args, kwargs) to
185
+ perform forward pass with merged parameters.
186
+
187
+ Example:
188
+ ```python
189
+ # Internal usage during forward pass
190
+ forward_fn = merged_model.forward_model
191
+ output = forward_fn(args=(x,), kwargs={})
192
+ ```
193
+ """
176
194
  return functools.partial(
177
195
  functional_call,
178
196
  self.pretrained_model,
@@ -181,10 +199,30 @@ class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
181
199
  strict=self.strict,
182
200
  )
183
201
 
184
- def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
202
+ def merge_and_unload(
203
+ self,
204
+ task_vector_mask: Optional[Dict[str, Tensor]] = None,
205
+ copy: bool = False,
206
+ ) -> TorchModelType:
207
+ """
208
+ Merge models and return the final merged model.
209
+
210
+ Args:
211
+ task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
212
+ for selective parameter merging. Defaults to None.
213
+ copy (bool, optional): Whether to return a deep copy of the pretrained model.
214
+ Defaults to False. If True, the original pretrained model remains unchanged.
215
+
216
+ Returns:
217
+ TorchModelType: The pretrained model with merged parameters loaded.
218
+ """
185
219
  self.merge_weights(task_vector_mask=task_vector_mask)
186
- self.pretrained_model.load_state_dict(self._merged_state_dict)
187
- return self.pretrained_model
220
+ if copy:
221
+ model = deepcopy(self.pretrained_model)
222
+ else:
223
+ model = self.pretrained_model
224
+ model.load_state_dict(self._merged_state_dict)
225
+ return model
188
226
 
189
227
  def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
190
228
  """
@@ -16,6 +16,7 @@ outputs = merged_model(inputs)
16
16
 
17
17
  import functools
18
18
  import logging
19
+ from copy import deepcopy
19
20
  from typing import Any, Callable, Dict, Generic, Iterator, List, Optional # noqa: F401
20
21
 
21
22
  import torch
@@ -327,7 +328,11 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
327
328
  self._merged_state_dict = state_dict
328
329
  return state_dict
329
330
 
330
- def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
331
+ def merge_and_unload(
332
+ self,
333
+ task_vector_mask: Optional[Dict[str, Tensor]] = None,
334
+ copy: bool = False,
335
+ ) -> TorchModelType:
331
336
  """
332
337
  Merge models and return the final merged model.
333
338
 
@@ -338,6 +343,8 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
338
343
  Args:
339
344
  task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
340
345
  for selective parameter merging. Defaults to None.
346
+ copy (bool, optional): Whether to return a deep copy of the pretrained model.
347
+ Defaults to False. If True, the original pretrained model remains unchanged.
341
348
 
342
349
  Returns:
343
350
  TorchModelType: The pretrained model with merged parameters loaded.
@@ -363,8 +370,12 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
363
370
  The original pretrained model parameters will be lost.
364
371
  """
365
372
  self.merge_weights(task_vector_mask=task_vector_mask)
366
- self.pretrained_model.load_state_dict(self._merged_state_dict)
367
- return self.pretrained_model
373
+ if copy:
374
+ model = deepcopy(self.pretrained_model)
375
+ else:
376
+ model = self.pretrained_model
377
+ model.load_state_dict(self._merged_state_dict)
378
+ return model
368
379
 
369
380
  def forward(self, *args, **kwargs):
370
381
  """
fusion_bench/py.typed ADDED
@@ -0,0 +1 @@
1
+
@@ -3,33 +3,24 @@
3
3
  This is the CLI script that is executed when the user runs the `fusion_bench` command.
4
4
  The script is responsible for parsing the command-line arguments, loading the configuration file, and running the fusion algorithm.
5
5
  """
6
-
7
- import importlib
8
- import importlib.resources
9
6
  import logging
10
- import os
7
+ from typing import TYPE_CHECKING
11
8
 
12
9
  import hydra
13
10
  from omegaconf import DictConfig, OmegaConf
14
11
 
15
12
  from fusion_bench.constants import PROJECT_ROOT_PATH
16
- from fusion_bench.programs import BaseHydraProgram
17
13
  from fusion_bench.utils import instantiate
14
+ from fusion_bench.utils.hydra_utils import get_default_config_path
18
15
 
19
- log = logging.getLogger(__name__)
20
-
16
+ if TYPE_CHECKING:
17
+ from fusion_bench.programs import BaseHydraProgram
21
18
 
22
- def _get_default_config_path():
23
- for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
24
- for config_dir in ["config", "fusion_bench_config"]:
25
- config_path = os.path.join(config_path_root, config_dir)
26
- if os.path.exists(config_path) and os.path.isdir(config_path):
27
- return os.path.abspath(config_path)
28
- return None
19
+ log = logging.getLogger(__name__)
29
20
 
30
21
 
31
22
  @hydra.main(
32
- config_path=_get_default_config_path(),
23
+ config_path=get_default_config_path(),
33
24
  config_name="fabric_model_fusion",
34
25
  version_base=None,
35
26
  )
@@ -68,7 +59,21 @@ def main(cfg: DictConfig) -> None:
68
59
  loading the corresponding configuration files to populate the cfg parameter.
69
60
  """
70
61
  OmegaConf.resolve(cfg)
71
- program: BaseHydraProgram = instantiate(cfg)
62
+ program: "BaseHydraProgram" = instantiate(cfg)
63
+
64
+ # Validate that instantiation succeeded and returned an object with 'run' method
65
+ if not hasattr(program, "run") or not callable(getattr(program, "run")):
66
+ err_msg = (
67
+ f"Expected an object with a callable 'run' method, but got {type(program).__name__}. "
68
+ "Ensure that the configuration specifies a concrete program class with '_target_'."
69
+ )
70
+ if "_target_" not in cfg:
71
+ err_msg += "\nThe '_target_' field is missing from the root configuration."
72
+ else:
73
+ err_msg += f"\nFound '_target_': {cfg._target_}"
74
+ err_msg += f"\n\nConfiguration content:\n{cfg}"
75
+ raise TypeError(err_msg)
76
+
72
77
  program.run()
73
78
 
74
79
 
@@ -9,7 +9,7 @@ import hydra
9
9
  from hydra import compose, initialize_config_dir
10
10
  from omegaconf import DictConfig, ListConfig, OmegaConf
11
11
 
12
- from fusion_bench.scripts.cli import _get_default_config_path
12
+ from fusion_bench.scripts.cli import get_default_config_path
13
13
 
14
14
  # Keeping the ConfigGroupNode and AppState classes as they are
15
15
  from fusion_bench.scripts.webui import AppState, ConfigGroupNode, priority_iterable
@@ -40,7 +40,7 @@ class App:
40
40
  if self.args.config_path:
41
41
  return Path(self.args.config_path)
42
42
  else:
43
- return _get_default_config_path()
43
+ return get_default_config_path()
44
44
 
45
45
  def generate_ui(self):
46
46
  dpg.create_context()
@@ -16,7 +16,7 @@ from hydra import compose, initialize_config_dir
16
16
  from hydra.core.hydra_config import HydraConfig
17
17
  from omegaconf import DictConfig, ListConfig, OmegaConf
18
18
 
19
- from fusion_bench.scripts.cli import _get_default_config_path
19
+ from fusion_bench.scripts.cli import get_default_config_path
20
20
 
21
21
 
22
22
  def escape_overrides(value: str) -> str:
@@ -385,7 +385,7 @@ class App:
385
385
  if self.args.config_path:
386
386
  return Path(self.args.config_path)
387
387
  else:
388
- return _get_default_config_path()
388
+ return get_default_config_path()
389
389
 
390
390
  def __getattr__(self, name):
391
391
  """
@@ -53,6 +53,7 @@ _import_structure = {
53
53
  "get_parameter_summary",
54
54
  "human_readable",
55
55
  "print_parameters",
56
+ "print_trainable_parameters",
56
57
  "state_dict_to_vector",
57
58
  "trainable_state_dict",
58
59
  "vector_to_state_dict",
@@ -138,6 +139,7 @@ if TYPE_CHECKING:
138
139
  get_parameter_summary,
139
140
  human_readable,
140
141
  print_parameters,
142
+ print_trainable_parameters,
141
143
  state_dict_to_vector,
142
144
  trainable_state_dict,
143
145
  vector_to_state_dict,
@@ -32,11 +32,13 @@ def clear_cuda_cache():
32
32
  Clears the CUDA memory cache to free up GPU memory.
33
33
  Works only if CUDA is available.
34
34
  """
35
+
35
36
  gc.collect()
36
37
  if torch.cuda.is_available():
37
38
  torch.cuda.empty_cache()
39
+ gc.collect()
38
40
  else:
39
- log.warning("CUDA is not available. No cache to clear.")
41
+ log.debug("CUDA is not available. No cache to clear.")
40
42
 
41
43
 
42
44
  def to_device(
@@ -1,4 +1,79 @@
1
+ import logging
2
+ import os
3
+
1
4
  import hydra.core.hydra_config
5
+ from hydra import compose, initialize
6
+ from omegaconf import DictConfig
7
+
8
+ from fusion_bench.constants import PROJECT_ROOT_PATH
9
+
10
+ log = logging.getLogger(__name__)
11
+
12
+
13
+ def get_default_config_path():
14
+ """
15
+ Get the default configuration path by searching in common locations.
16
+ """
17
+ for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
18
+ for config_dir in ["config", "fusion_bench_config"]:
19
+ config_path = os.path.join(config_path_root, config_dir)
20
+ if os.path.exists(config_path) and os.path.isdir(config_path):
21
+ return os.path.abspath(config_path)
22
+ return None
23
+
24
+
25
+ def initialize_hydra_config(
26
+ config_name: str,
27
+ overrides: list[str] = None,
28
+ config_path: str = None,
29
+ return_hydra_config: bool = False,
30
+ ) -> DictConfig:
31
+ """
32
+ Load the Hydra configuration.
33
+
34
+ Args:
35
+ config_name (str): The name of the configuration file (without .yaml extension).
36
+ overrides (list[str]): A list of configuration overrides.
37
+ config_path (str): The path to the configuration directory. If None, it will be automatically detected.
38
+ return_hydra_config (bool): If True, return the Hydra configuration object.
39
+
40
+ Returns:
41
+ DictConfig: The loaded configuration.
42
+
43
+ Example:
44
+ >>> cfg = initialize_hydra_config(
45
+ ... config_name="fabric_model_fusion",
46
+ ... overrides=["method=dummy", "modelpool=dummy"],
47
+ ... )
48
+ >>> print(cfg.method)
49
+ """
50
+ if config_path is None:
51
+ config_path = get_default_config_path()
52
+
53
+ # check config_path validity
54
+ if config_path is None:
55
+ raise FileNotFoundError("Could not find configuration directory.")
56
+ if not os.path.isdir(config_path):
57
+ raise NotADirectoryError(
58
+ f"Configuration path {config_path} do not exists or is not a directory."
59
+ )
60
+
61
+ if overrides is None:
62
+ overrides = []
63
+
64
+ with initialize(
65
+ version_base=None,
66
+ config_path=os.path.relpath(
67
+ config_path,
68
+ start=os.path.dirname(__file__),
69
+ ),
70
+ ):
71
+ cfg = compose(
72
+ config_name=config_name,
73
+ overrides=overrides,
74
+ return_hydra_config=return_hydra_config,
75
+ )
76
+ return cfg
2
77
 
3
78
 
4
79
  def get_hydra_output_dir():
@@ -14,8 +14,8 @@ from lightning_utilities.core.rank_zero import rank_zero_only
14
14
  from omegaconf import DictConfig, OmegaConf, SCMode
15
15
  from omegaconf._utils import is_structured_config
16
16
  from rich import print
17
- from rich.panel import Panel
18
- from rich.syntax import Syntax
17
+
18
+ from fusion_bench.utils.rich_utils import print_bordered
19
19
 
20
20
  PRINT_FUNCTION_CALL = True
21
21
  """
@@ -67,12 +67,22 @@ def _resolve_callable_name(f: Callable[..., Any]) -> str:
67
67
  return full_name
68
68
 
69
69
 
70
- def _format_args_kwargs(args, kwargs):
70
+ def _get_obj_str(obj: Any) -> str:
71
+ if isinstance(obj, (str, int, float, bool, type(None))):
72
+ return repr(obj)
73
+ else:
74
+ return f"'<{type(obj).__name__} object>'"
75
+
76
+
77
+ def _format_args_kwargs(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> str:
71
78
  result_strings = []
72
79
  if len(args) > 0:
73
- result_strings.append(", ".join(repr(arg) for arg in args))
80
+ result_strings.append(", ".join(_get_obj_str(arg) for arg in args))
81
+
74
82
  if len(kwargs) > 0:
75
- result_strings.append(", ".join(f"{k}={repr(v)}" for k, v in kwargs.items()))
83
+ result_strings.append(
84
+ ", ".join(f"{k}={_get_obj_str(v)}" for k, v in kwargs.items())
85
+ )
76
86
 
77
87
  if len(result_strings) == 0:
78
88
  return ""
@@ -145,14 +155,14 @@ def _call_target(
145
155
  if _partial_:
146
156
  if PRINT_FUNCTION_CALL and getattr(rank_zero_only, "rank", 0) == 0:
147
157
  call_str = f"functools.partial({_resolve_callable_name(_target_)}, {_format_args_kwargs(args, kwargs)})"
148
- PRINT_FUNCTION_CALL_FUNC(
149
- Panel(
150
- Syntax(call_str, "python", theme="monokai", word_wrap=True),
151
- title="Instantiate by calling partial",
152
- border_style="cyan",
153
- )
158
+ print_bordered(
159
+ call_str,
160
+ code_style="python",
161
+ title=f"Instantiate by calling {'function' if not isinstance(_target_, type) else 'class'}",
162
+ style="cyan",
163
+ expand=False,
164
+ print_fn=PRINT_FUNCTION_CALL_FUNC,
154
165
  )
155
-
156
166
  if CATCH_EXCEPTION:
157
167
  try:
158
168
  return functools.partial(_target_, *args, **kwargs)
@@ -169,12 +179,13 @@ def _call_target(
169
179
  else:
170
180
  if PRINT_FUNCTION_CALL and getattr(rank_zero_only, "rank", 0) == 0:
171
181
  call_str = f"{_resolve_callable_name(_target_)}({_format_args_kwargs(args, kwargs)})"
172
- PRINT_FUNCTION_CALL_FUNC(
173
- Panel(
174
- Syntax(call_str, "python", theme="monokai", word_wrap=True),
175
- title="Instantiate by calling function",
176
- border_style="green",
177
- )
182
+ print_bordered(
183
+ call_str,
184
+ code_style="python",
185
+ title=f"Instantiate by calling {'function' if not isinstance(_target_, type) else 'class'}",
186
+ style="green",
187
+ expand=False,
188
+ print_fn=PRINT_FUNCTION_CALL_FUNC,
178
189
  )
179
190
  if CATCH_EXCEPTION:
180
191
  try:
@@ -178,3 +178,19 @@ def validate_and_suggest_corrections(
178
178
  if matches:
179
179
  msg += f". Did you mean {', '.join(repr(m) for m in matches)}?"
180
180
  raise ValueError(msg)
181
+
182
+
183
+ class DeprecationWarningMeta(type):
184
+ """
185
+ Metaclass that issues a deprecation warning whenever a class using it is instantiated.
186
+ """
187
+
188
+ def __call__(cls, *args, **kwargs):
189
+ import warnings
190
+
191
+ warnings.warn(
192
+ f"{cls.__name__} is deprecated and will be removed in a future version. ",
193
+ DeprecationWarning,
194
+ stacklevel=2,
195
+ )
196
+ return super(DeprecationWarningMeta, cls).__call__(*args, **kwargs)
@@ -10,6 +10,7 @@ from .type import StateDictType
10
10
  __all__ = [
11
11
  "count_parameters",
12
12
  "print_parameters",
13
+ "print_trainable_parameters",
13
14
  "check_parameters_all_equal",
14
15
  "get_parameter_statistics",
15
16
  "state_dict_to_vector",
@@ -282,6 +283,38 @@ def print_parameters(
282
283
  )
283
284
 
284
285
 
286
+ def print_trainable_parameters(
287
+ module: nn.Module,
288
+ is_human_readable: bool = True,
289
+ print_fn=print,
290
+ non_zero_only: bool = False,
291
+ ):
292
+ """
293
+ Print the names and number of trainable parameters in a PyTorch model.
294
+
295
+ Args:
296
+ module (nn.Module): The PyTorch model.
297
+ is_human_readable (bool, optional): Whether to print the number of parameters in a human-readable format. Defaults to True.
298
+ print_fn (callable, optional): The function to use for printing. Defaults to print.
299
+ non_zero_only (bool, optional): Whether to count only non-zero parameters. Defaults to False.
300
+
301
+ Prints:
302
+ The names and number of trainable parameters in the model.
303
+
304
+ ```python
305
+ print_trainable_parameters(model)
306
+ # weight: 1.50M parameters
307
+ # bias: 500.00K parameters
308
+ ```
309
+ """
310
+ for name, param in module.named_parameters():
311
+ if param.requires_grad:
312
+ num_params = _numel(param, non_zero_only=non_zero_only)
313
+ if is_human_readable:
314
+ num_params = human_readable(num_params)
315
+ print_fn(f"{name}: {num_params} parameters")
316
+
317
+
285
318
  def check_parameters_all_equal(
286
319
  list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]],
287
320
  ) -> None: