fusion-bench 0.2.31__py3-none-any.whl → 0.2.32__py3-none-any.whl

This diff shows the contents of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (51)
  1. fusion_bench/__init__.py +6 -0
  2. fusion_bench/__main__.py +2 -2
  3. fusion_bench/dataset/__init__.py +2 -0
  4. fusion_bench/dataset/clip_dataset.py +4 -72
  5. fusion_bench/dataset/image_dataset.py +44 -18
  6. fusion_bench/method/base_algorithm.py +4 -0
  7. fusion_bench/method/dop/dop.py +0 -22
  8. fusion_bench/method/dop/dop_general.py +489 -0
  9. fusion_bench/method/dop/utils.py +24 -4
  10. fusion_bench/method/emr_merging/__init__.py +1 -0
  11. fusion_bench/method/emr_merging/emr_merging.py +53 -0
  12. fusion_bench/method/emr_merging/utils.py +162 -0
  13. fusion_bench/method/opcm/opcm.py +6 -2
  14. fusion_bench/method/opcm/opcm_general.py +356 -0
  15. fusion_bench/method/opcm/utils.py +1 -4
  16. fusion_bench/method/simple_average.py +52 -18
  17. fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
  18. fusion_bench/mixins/lightning_fabric.py +108 -3
  19. fusion_bench/mixins/serialization.py +1 -1
  20. fusion_bench/modelpool/base_pool.py +37 -1
  21. fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
  22. fusion_bench/models/hf_clip.py +20 -0
  23. fusion_bench/models/modulator/__init__.py +1 -0
  24. fusion_bench/models/modulator/base.py +123 -0
  25. fusion_bench/models/parameter_dict.py +119 -29
  26. fusion_bench/models/utils.py +190 -2
  27. fusion_bench/models/wrappers/switch.py +90 -0
  28. fusion_bench/programs/base_program.py +6 -0
  29. fusion_bench/programs/fabric_fusion_program.py +4 -0
  30. fusion_bench/scripts/cli.py +19 -8
  31. fusion_bench/taskpool/image_classification.py +270 -0
  32. fusion_bench/utils/__init__.py +18 -1
  33. fusion_bench/utils/data.py +1 -1
  34. fusion_bench/utils/dict.py +19 -0
  35. fusion_bench/utils/dtype.py +19 -0
  36. fusion_bench/utils/misc.py +1 -0
  37. fusion_bench/utils/packages.py +4 -0
  38. fusion_bench/utils/state_dict_arithmetic.py +183 -1
  39. fusion_bench/utils/tensorboard.py +21 -3
  40. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
  41. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +51 -37
  42. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
  43. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
  44. fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
  45. fusion_bench_config/method/dop/dop_general.yaml +33 -0
  46. fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
  47. fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
  48. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
  49. fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
  50. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
  51. {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0

fusion_bench/models/parameter_dict.py:

@@ -1,12 +1,12 @@
-from typing import List, Mapping, Optional, Tuple
+from typing import Iterator, List, Mapping, Optional, Tuple, Union

 import torch
 from torch import nn

-__all__ = "ParamterDictModel"
+__all__ = ["ParameterDictModel"]


-def _set_attr(
+def set_nested_attr(
     obj,
     names: List[str],
     val,
@@ -27,7 +27,7 @@ def _set_attr(
     else:
         if check_parent and not hasattr(obj, names[0]):
             setattr(obj, names[0], parent_builder())
-        _set_attr(
+        set_nested_attr(
             getattr(obj, names[0]),
             names[1:],
             val,
@@ -36,7 +36,7 @@ def _set_attr(
         )


-def has_attr(obj, names: List[str]):
+def has_nested_attr(obj, names: List[str]):
     """
     Checks if an attribute exists in an object recursively.

@@ -50,26 +50,49 @@ def has_attr(obj, names: List[str]):
     if len(names) == 1:
         return hasattr(obj, names[0])
     else:
-        return has_attr(getattr(obj, names[0]), names[1:])
+        if not hasattr(obj, names[0]):
+            return False
+        return has_nested_attr(getattr(obj, names[0]), names[1:])


 class ParameterDictModel(nn.Module):
     """
-    This model is used to create a model with parameters from a dictionary.
-    It behaves like a normal `nn.ParameterDict`, but support keys with dots.
+    A module that stores parameters in a nested dictionary structure.
+
+    This model behaves similarly to `nn.ParameterDict`, but supports hierarchical keys
+    with dots (e.g., "layer1.weight"). Parameters are stored as nested attributes,
+    allowing for structured parameter access and manipulation.
+
+    Example:
+        >>> params = {
+        ...     "encoder.weight": nn.Parameter(torch.randn(10, 5)),
+        ...     "decoder.bias": nn.Parameter(torch.randn(5)),
+        ... }
+        >>> model = ParameterDictModel(params)
+        >>> model["encoder.weight"].shape
+        torch.Size([10, 5])
+        >>> "encoder.weight" in model
+        True
     """

     def __init__(
         self,
-        parameters: Optional[Mapping[str, nn.Parameter]] = None,
-    ):
+        parameters: Optional[Mapping[str, Union[nn.Parameter, torch.Tensor]]] = None,
+    ) -> None:
+        """
+        Args:
+            parameters: Optional mapping of parameter names to parameter tensors.
+                Keys can contain dots to create nested structures.
+                Values must be `nn.Parameter` or `nn.Buffer` instances.
+        """
+
         super().__init__()
         if parameters is not None:
             for name, param in parameters.items():
                 assert isinstance(
                     param, (nn.Parameter, nn.Buffer)
                 ), f"{name} is not a nn.Parameter or nn.Buffer"
-                _set_attr(
+                set_nested_attr(
                     self,
                     name.split("."),
                     param,
@@ -77,12 +100,13 @@ class ParameterDictModel(nn.Module):
                     parent_builder=__class__,
                 )

-    def __repr__(self):
+    def __repr__(self) -> str:
         """
         Generate a string representation of the model's parameters.

         Returns:
-            str: A string representation of the model's parameters.
+            A string representation of the model's parameters in the format:
+            "ParameterDictModel(name1: shape1, name2: shape2, ...)"
         """
         param_reprs = []
         for name, param in self.named_parameters():
@@ -90,32 +114,98 @@ class ParameterDictModel(nn.Module):
             param_reprs.append(param_repr)
         return f"{self.__class__.__name__}({', '.join(param_reprs)})"

-    def __getitem__(self, key: str):
-        if not has_attr(self, key.split(".")):
+    def __iter__(self) -> Iterator[str]:
+        """
+        Iterate over the model's parameters.
+
+        Yields:
+            Tuples of (parameter name, parameter tensor).
+        """
+        yield from self.keys()
+
+    def __getitem__(
+        self, key: str
+    ) -> Union[nn.Parameter, torch.Tensor, "ParameterDictModel"]:
+        """
+        Retrieve a parameter or nested submodule by key.
+
+        Args:
+            key: Parameter name, which can contain dots for nested access.
+
+        Returns:
+            The parameter, tensor, or nested ParameterDictModel at the specified key.
+
+        Raises:
+            KeyError: If the key is not found in the model.
+        """
+        assert isinstance(
+            key, str
+        ), f"Key must be a string, but got {type(key)}: {key}."
+        if not has_nested_attr(self, key.split(".")):
             raise KeyError(f"Key {key} not found in {self}")
-        key = key.split(".")
+        key_parts = key.split(".")
         obj = self
-        for k in key:
+        for k in key_parts:
             obj = getattr(obj, k)
         return obj

-    def __setitem__(self, key: str, value: nn.Parameter):
-        if not has_attr(self, key.split(".")):
-            _set_attr(self, key.split("."), value, check_parent=True)
+    def __setitem__(self, key: str, value: Union[nn.Parameter, torch.Tensor]) -> None:
+        """
+        Set a parameter at the specified key, creating nested structure if needed.
+
+        Args:
+            key: Parameter name, which can contain dots for nested assignment.
+            value: Parameter or tensor to assign.
+        """
+        if not has_nested_attr(self, key.split(".")):
+            set_nested_attr(self, key.split("."), value, check_parent=True)
         else:
-            _set_attr(self, key.split("."), value, check_parent=False)
+            set_nested_attr(self, key.split("."), value, check_parent=False)
+
+    def __contains__(self, key: str) -> bool:
+        """
+        Check if a parameter key exists in the model.

-    def __contains__(self, key: str):
-        return has_attr(self, key.split("."))
+        Args:
+            key: Parameter name, which can contain dots for nested checking.
+
+        Returns:
+            True if the key exists, False otherwise.
+        """
+        return has_nested_attr(self, key.split("."))

     def keys(self):
-        return [name for name, _ in self.named_parameters()]
+        """
+        Return a list of all parameter names in the model.
+
+        Returns:
+            List of parameter names (including nested names with dots).
+        """
+        return self.state_dict().keys()
+
+    def items(self):
+        """
+        Return a list of (name, parameter) tuples.
+
+        Returns:
+            List of tuples containing parameter names and their corresponding tensors.
+        """
+        yield from self.state_dict().items()

-    def items(self) -> List[Tuple[str, nn.Parameter]]:
-        return [(name, self[name]) for name in self.keys()]
+    def values(self):
+        """
+        Return a list of all parameter values in the model.

-    def values(self) -> List[nn.Parameter]:
-        return [self[name] for name in self.keys()]
+        Returns:
+            List of parameter tensors.
+        """
+        yield from self.state_dict().values()

-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Return the number of parameters in the model.
+
+        Returns:
+            The total number of parameters.
+        """
         return len(self.keys())
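
Note on the parameter_dict.py changes: keys()/items()/values() are now backed by state_dict(), and __iter__/__contains__ give the class a dict-like read interface over dotted parameter names. A minimal usage sketch, relying only on the behavior documented in the docstrings above:

    import torch
    from torch import nn

    from fusion_bench.models.parameter_dict import ParameterDictModel

    # Flat, dot-separated parameter names (same example as the class docstring).
    params = {
        "encoder.weight": nn.Parameter(torch.randn(10, 5)),
        "decoder.bias": nn.Parameter(torch.randn(5)),
    }
    model = ParameterDictModel(params)

    # Dict-style access and membership tests work on dotted keys.
    assert "encoder.weight" in model
    print(model["encoder.weight"].shape)  # torch.Size([10, 5])

    # keys()/items()/values() are backed by state_dict(), so iteration
    # yields the flattened dotted names.
    for name, tensor in model.items():
        print(name, tuple(tensor.shape))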

fusion_bench/models/utils.py:

@@ -1,9 +1,37 @@
-from typing import List
+from typing import Iterable, List, Optional

 import torch
 from torch import nn
+from torch.nn.modules.module import _IncompatibleKeys

-from fusion_bench.utils.type import StateDictType
+from fusion_bench.utils.dict import dict_merge
+from fusion_bench.utils.type import StateDictType, TorchModelType
+
+
+def is_leaf_module(module: nn.Module) -> bool:
+    return len(list(module.children())) == 0
+
+
+def named_leaf_modules(
+    module: nn.Module,
+    prefix: str = "",
+    ignore_empty: bool = True,
+) -> Iterable[tuple[str, nn.Module]]:
+    """
+    Recursively find the leaf modules in a module.
+
+    Args:
+        module (nn.Module): PyTorch module.
+        prefix (str): A prefix to add to the layer names.
+
+    Returns:
+        Iterable[tuple[str, nn.Module]]: An iterable of (name, module) tuples for each leaf module.
+    """
+    for name, submodule in module.named_modules(prefix=prefix):
+        if is_leaf_module(submodule):
+            if ignore_empty and len(list(submodule.parameters())) == 0:
+                continue
+            yield name, submodule


 def del_attr(obj, names: List[str]):
@@ -104,3 +132,163 @@ def disable_dropout(model: torch.nn.Module):
     for module in model.modules():
         if isinstance(module, torch.nn.Dropout):
             module.p = 0
+
+
+def get_target_state_dict(
+    module: nn.Module,
+    target_modules: str | Iterable[str] | None = None,
+    prefix: str = "",
+    keep_vars: bool = False,
+) -> StateDictType:
+    """
+    This function retrieves the state dictionary of specified target submodules within a given module
+    of a PyTorch model or merged state dictionary from multiple submodules.
+
+    For example, if a model has submodules named "layer1", "layer2", and "layer3", and you want to get the state dictionary of "layer1" and "layer3",
+    you can call this function with `target_modules` set to `["layer1", "layer3"]`.
+    The function will return a state dictionary that includes only the parameters and buffers from those specified submodules.
+
+    Args:
+        module (nn.Module): The PyTorch module containing the target submodules.
+        target_modules (str | Iterable[str]): A single target module name or an iterable of target module names.
+            If None, the entire module's state dictionary is returned if no special attribute is set (look up the `_fusion_bench_target_modules` attribute).
+        keep_vars (bool): If True, keeps the variables in the state dictionary. Default is False.
+
+    Returns:
+        StateDictType: The state dictionary of the specified target submodules, merged if multiple are provided.
+    """
+    if target_modules is None:
+        if (
+            hasattr(module, "_fusion_bench_target_modules")
+            and module._fusion_bench_target_modules is not None
+        ):
+            return get_target_state_dict(
+                module,
+                target_modules=module._fusion_bench_target_modules,
+                prefix=prefix,
+                keep_vars=keep_vars,
+            )
+        else:
+            return module.state_dict(prefix=prefix, keep_vars=keep_vars)
+
+    if isinstance(target_modules, str):
+        target_modules = [target_modules]
+
+    state_dicts = []
+    for target_module in target_modules:
+        submodule_prefix = (
+            f"{prefix}{target_module}." if prefix else f"{target_module}."
+        )
+        submodule = module.get_submodule(target_module)
+        state_dict = submodule.state_dict(prefix=submodule_prefix, keep_vars=keep_vars)
+        state_dicts.append(state_dict)
+
+    merged_state_dict = dict_merge(state_dicts, disjoint=True)
+    return merged_state_dict
+
+
+def validate_target_modules_equal(modules: Iterable[nn.Module]) -> None:
+    """
+    Validates that the `_fusion_bench_target_modules` attribute is the same across all provided modules.
+
+    Args:
+        modules (Iterable[nn.Module]): An iterable of PyTorch modules to validate.
+
+    Raises:
+        ValueError: If the `_fusion_bench_target_modules` attribute differs among the modules.
+    """
+    model_iter = iter(modules)
+    first_module = next(model_iter)
+
+    if hasattr(first_module, "_fusion_bench_target_modules"):
+        target_modules = first_module._fusion_bench_target_modules
+    else:
+        # if the module does not have the attribute, set to None
+        target_modules = None
+
+    for module in model_iter:
+        if target_modules is None:
+            if (
+                hasattr(module, "_fusion_bench_target_modules")
+                and module._fusion_bench_target_modules != target_modules
+            ):
+                raise ValueError(
+                    "_fusion_bench_target_modules attribute differs among the provided modules."
+                )
+        else:
+            if (
+                not hasattr(module, "_fusion_bench_target_modules")
+                or module._fusion_bench_target_modules != target_modules
+            ):
+                raise ValueError(
+                    "_fusion_bench_target_modules attribute differs among the provided modules."
+                )
+
+
+def load_state_dict_into_target_modules(
+    module: TorchModelType,
+    state_dict: StateDictType,
+    target_modules: str | Iterable[str] | None = None,
+    strict: bool = True,
+    assign: bool = False,
+):
+    """
+    Load a state dictionary into specified target submodules within a given module of a PyTorch model.
+
+    This function allows you to load parameters and buffers from a state dictionary into specific submodules
+    of a PyTorch model. If the `target_modules` argument is provided, only the specified submodules will be updated
+    with the corresponding entries from the state dictionary.
+
+    Args:
+        module (nn.Module): The PyTorch module containing the target submodules.
+        state_dict (StateDictType): The state dictionary containing parameters and buffers to load.
+        target_modules (str | Iterable[str]): A single target module name or an iterable of target module names.
+            If None, the entire module's state dictionary is updated if no special attribute is set
+            (look up the `_fusion_bench_target_modules` attribute).
+        strict (bool): Whether to strictly enforce that the keys in `state_dict` match the keys returned by
+            the module's `state_dict()` function. Default is True.
+    """
+    if target_modules is None:
+        if (
+            hasattr(module, "_fusion_bench_target_modules")
+            and module._fusion_bench_target_modules is not None
+        ):
+            return load_state_dict_into_target_modules(
+                module,
+                state_dict,
+                target_modules=module._fusion_bench_target_modules,
+                strict=strict,
+                assign=assign,
+            )
+        else:
+            return module.load_state_dict(state_dict, strict=strict, assign=assign)
+
+    if isinstance(target_modules, str):
+        target_modules = [target_modules]
+
+    assert (
+        len(target_modules) > 0
+    ), "target_modules should contain at least one module name."
+    results: list[_IncompatibleKeys] = []
+    for target_module in target_modules:
+        submodule_prefix = f"{target_module}."
+        submodule_prefix_len = len(submodule_prefix)
+        submodule = module.get_submodule(target_module)
+
+        # Extract the relevant portion of the state dictionary for the submodule
+        submodule_state_dict = {
+            key[submodule_prefix_len:]: value for key, value in state_dict.items()
+        }
+
+        # Load the extracted state dictionary into the submodule
+        result = submodule.load_state_dict(
+            submodule_state_dict, strict=strict, assign=assign
+        )
+        results.append(result)
+
+    # Merge results from all submodules
+    merged_result = _IncompatibleKeys(
+        missing_keys=[key for res in results for key in res.missing_keys],
+        unexpected_keys=[key for res in results for key in res.unexpected_keys],
+    )
+    return merged_result
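
Note on the models/utils.py additions: get_target_state_dict and load_state_dict_into_target_modules operate on a subset of submodules, either passed explicitly or read from a module's `_fusion_bench_target_modules` attribute. A rough sketch with a toy model; the submodule names ("encoder", "head") are illustrative only:

    import torch
    from torch import nn

    from fusion_bench.models.utils import (
        get_target_state_dict,
        load_state_dict_into_target_modules,
    )

    # Toy model; the submodule names are illustrative only.
    model = nn.ModuleDict({"encoder": nn.Linear(8, 4), "head": nn.Linear(4, 2)})

    # State dict restricted to selected submodules (keys keep their full prefixes).
    sd = get_target_state_dict(model, target_modules=["encoder", "head"])
    print(sorted(sd.keys()))
    # ['encoder.bias', 'encoder.weight', 'head.bias', 'head.weight']

    # Load a (possibly modified) per-submodule state dict back into a single target.
    encoder_sd = get_target_state_dict(model, target_modules="encoder")
    result = load_state_dict_into_target_modules(
        model, encoder_sd, target_modules="encoder"
    )
    print(result.missing_keys, result.unexpected_keys)  # [] []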

fusion_bench/models/wrappers/switch.py (new file):

@@ -0,0 +1,90 @@
+"""
+This module contains a wrapper for switching between different models.
+
+For example, it can be used to switch between different classification heads for a shared backbone.
+"""
+
+import logging
+from typing import Dict, Optional
+
+from torch import nn
+
+from fusion_bench.utils.misc import first, validate_and_suggest_corrections
+
+__all__ = ["SwitchModule", "set_active_option"]
+
+log = logging.getLogger(__name__)
+
+
+def _standardize_option_name(name: str) -> str:
+    """
+    Standardizes the option name by:
+
+    - Stripping whitespace and converting to lowercase.
+    - Replacing `-` with `_` if needed.
+    - Replacing `/` with `_` if needed.
+
+    Args:
+        name (str): The option name to standardize.
+    """
+    name = name.strip().lower()
+    name = name.replace("-", "_")
+    name = name.replace("/", "_")
+    return name
+
+
+class SwitchModule(nn.Module):
+    """
+    A wrapper module that contains multiple sub-modules (options) and allows switching between them.
+
+    This is useful for multi-head models or models where different parts are activated based on the task.
+    """
+
+    def __init__(self, modules: Dict[str, nn.Module]):
+        """
+        Args:
+            modules (Dict[str, nn.Module]): A dictionary of modules to switch between.
+        """
+        super().__init__()
+        standardized_modules = {
+            _standardize_option_name(name): module for name, module in modules.items()
+        }
+        self._option_modules = nn.ModuleDict(standardized_modules)
+        self._active_option = first(self._option_modules.keys())
+
+    def set_active_option(self, option_name: str):
+        standardized_name = _standardize_option_name(option_name)
+        validate_and_suggest_corrections(standardized_name, self._option_modules.keys())
+        self._active_option = standardized_name
+
+    def forward(self, *args, **kwargs):
+        active_module = self._option_modules[self._active_option]
+        return active_module(*args, **kwargs)
+
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            active_module = self._option_modules[self._active_option]
+            if hasattr(active_module, name):
+                return getattr(active_module, name)
+            raise
+
+
+def set_active_option(module: nn.Module, option_name: str) -> list[str]:
+    """
+    Utility function to set the active option for all SwitchModule instances within a given module.
+
+    Args:
+        module (nn.Module): The module to set the active option for.
+        option_name (str): The name of the option to activate.
+
+    Returns:
+        list[str]: A list of names of submodules that were activated.
+    """
+    activated_submodules = []
+    for name, submodule in module.named_modules():
+        if isinstance(submodule, SwitchModule):
+            submodule.set_active_option(option_name)
+            activated_submodules.append(name)
+    return activated_submodules
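
Note on the new SwitchModule wrapper: options are stored in an nn.ModuleDict under standardized names, and forward() dispatches to the currently active option. A small sketch with hypothetical head names, based only on the code above:

    import torch
    from torch import nn

    from fusion_bench.models.wrappers.switch import SwitchModule, set_active_option

    # Two task-specific heads over a shared 16-dim feature space (hypothetical names).
    heads = SwitchModule({
        "CIFAR-10": nn.Linear(16, 10),  # stored under the standardized name "cifar_10"
        "MNIST": nn.Linear(16, 10),
    })

    features = torch.randn(2, 16)
    heads.set_active_option("cifar-10")  # option lookup is case- and "-"-insensitive
    logits = heads(features)             # dispatched to the CIFAR-10 head

    # Flip every SwitchModule inside a larger model at once.
    model = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), heads)
    activated = set_active_option(model, "MNIST")
    print(activated)  # names of the SwitchModule submodules that were switched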

fusion_bench/programs/base_program.py:

@@ -75,6 +75,12 @@ class BaseHydraProgram(BaseYAMLSerializable):
        - FusionBench CLI documentation for program execution details
     """

+    _program = None
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._program = self
+
     @abstractmethod
     def run(self):
         """

fusion_bench/programs/fabric_fusion_program.py:

@@ -267,6 +267,7 @@ class FabricModelFusionProgram(
         merged_model = self.method.run(self.modelpool)
         self.method.on_run_end()

+        report = None
         if merged_model is None:
             log.info(
                 "No merged model returned by the method. Skipping saving and evaluation."
@@ -293,5 +294,8 @@
                 )
                 os.makedirs(os.path.dirname(self.report_save_path), exist_ok=True)
                 json.dump(report, open(self.report_save_path, "w"))
+                self.log_artifact(local_path=self.report_save_path)
         else:
             log.info("No task pool specified. Skipping evaluation.")
+
+        return {"merged_model": merged_model, "report": report}
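
Note on the fabric_fusion_program.py change: run() now returns both the merged model and the evaluation report (the report is None when no task pool is configured), and the saved report file is additionally logged as an artifact. A hedged sketch of handling the new return value; the helper below is purely illustrative:

    from typing import Any, Dict


    def consume_fusion_outputs(outputs: Dict[str, Any]) -> None:
        """Illustrative handler for the dict now returned by FabricModelFusionProgram.run()."""
        merged_model = outputs["merged_model"]  # the fused model, or None if the method returned nothing
        report = outputs["report"]              # evaluation report, or None if no task pool was set
        if report is not None:
            print(report)
        if merged_model is not None:
            print(type(merged_model).__name__)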

fusion_bench/scripts/cli.py:

@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING
 import hydra
 from omegaconf import DictConfig, OmegaConf

-from fusion_bench.constants import PROJECT_ROOT_PATH
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.hydra_utils import get_default_config_path

@@ -19,11 +18,6 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)


-@hydra.main(
-    config_path=get_default_config_path(),
-    config_name="fabric_model_fusion",
-    version_base=None,
-)
 def main(cfg: DictConfig) -> None:
     """
     Main entry point for the FusionBench command-line interface.
@@ -74,8 +68,25 @@ def main(cfg: DictConfig) -> None:
         err_msg += f"\n\nConfiguration content:\n{cfg}"
         raise TypeError(err_msg)

-    program.run()
+    try:
+        program_result = program.run()
+        return program_result
+    except BaseException as e:
+        # Log the exception before exiting
+        if hasattr(program, "finalize") and callable(getattr(program, "finalize")):
+            program.finalize()
+        log.error(e, exc_info=True)
+        raise e
+
+
+@hydra.main(
+    config_path=get_default_config_path(),
+    config_name="fabric_model_fusion",
+    version_base=None,
+)
+def _hydra_main(cfg: DictConfig) -> None:
+    main(cfg)


 if __name__ == "__main__":
-    main()
+    _hydra_main()
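
Note on the cli.py refactor: the @hydra.main decorator moved to the new _hydra_main wrapper, so main(cfg) can be called directly with a pre-built DictConfig and now propagates the program's return value; on errors it calls program.finalize() (when defined) before re-raising. A rough sketch of a programmatic call via Hydra's compose API; this is an assumption-laden example (it presumes get_default_config_path() resolves to a config directory Hydra can initialize from, and that the default "fabric_model_fusion" config composes without extra overrides):

    from hydra import compose, initialize_config_dir

    from fusion_bench.scripts.cli import main
    from fusion_bench.utils.hydra_utils import get_default_config_path

    # Build the same "fabric_model_fusion" config the CLI uses, then call `main`
    # directly; unlike the decorated `_hydra_main`, it returns the program result.
    with initialize_config_dir(config_dir=get_default_config_path(), version_base=None):
        cfg = compose(config_name="fabric_model_fusion")

    result = main(cfg)  # for the fabric program: {"merged_model": ..., "report": ...}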