fusion-bench 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/__main__.py +2 -2
- fusion_bench/constants/runtime.py +4 -1
- fusion_bench/dataset/__init__.py +2 -0
- fusion_bench/dataset/clip_dataset.py +4 -72
- fusion_bench/dataset/image_dataset.py +44 -18
- fusion_bench/method/base_algorithm.py +4 -0
- fusion_bench/method/classification/image_classification_finetune.py +1 -0
- fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
- fusion_bench/method/dop/dop.py +0 -22
- fusion_bench/method/dop/dop_general.py +489 -0
- fusion_bench/method/dop/utils.py +24 -4
- fusion_bench/method/emr_merging/__init__.py +1 -0
- fusion_bench/method/emr_merging/emr_merging.py +53 -0
- fusion_bench/method/emr_merging/utils.py +162 -0
- fusion_bench/method/opcm/opcm.py +6 -2
- fusion_bench/method/opcm/opcm_general.py +356 -0
- fusion_bench/method/opcm/utils.py +1 -4
- fusion_bench/method/simple_average.py +52 -18
- fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
- fusion_bench/method/task_singular_vector/TSVM.py +7 -6
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
- fusion_bench/mixins/lightning_fabric.py +110 -11
- fusion_bench/mixins/openclip_classification.py +155 -1
- fusion_bench/mixins/serialization.py +1 -1
- fusion_bench/modelpool/base_pool.py +37 -0
- fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
- fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
- fusion_bench/models/hf_clip.py +20 -0
- fusion_bench/models/modulator/__init__.py +1 -0
- fusion_bench/models/modulator/base.py +123 -0
- fusion_bench/models/open_clip/modeling.py +61 -5
- fusion_bench/models/open_clip/utils.py +13 -2
- fusion_bench/models/parameter_dict.py +119 -29
- fusion_bench/models/utils.py +190 -2
- fusion_bench/models/wrappers/switch.py +90 -0
- fusion_bench/programs/base_program.py +6 -0
- fusion_bench/programs/fabric_fusion_program.py +4 -0
- fusion_bench/py.typed +1 -0
- fusion_bench/scripts/cli.py +25 -23
- fusion_bench/scripts/imgui.py +2 -2
- fusion_bench/scripts/webui.py +2 -2
- fusion_bench/taskpool/image_classification.py +270 -0
- fusion_bench/utils/__init__.py +20 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/dict.py +19 -0
- fusion_bench/utils/dtype.py +19 -0
- fusion_bench/utils/hydra_utils.py +75 -0
- fusion_bench/utils/misc.py +1 -0
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/rich_utils.py +42 -19
- fusion_bench/utils/state_dict_arithmetic.py +183 -1
- fusion_bench/utils/tensorboard.py +21 -3
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +70 -53
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
- fusion_bench_config/README.md +9 -0
- fusion_bench_config/fabric/auto.yaml +1 -0
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
- fusion_bench_config/hydra/default.yaml +3 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
- fusion_bench_config/method/dop/dop_general.yaml +33 -0
- fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
- fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
- fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/models/utils.py
CHANGED
|
@@ -1,9 +1,37 @@
|
|
|
1
|
-
from typing import List
|
|
1
|
+
from typing import Iterable, List, Optional
|
|
2
2
|
|
|
3
3
|
import torch
|
|
4
4
|
from torch import nn
|
|
5
|
+
from torch.nn.modules.module import _IncompatibleKeys
|
|
5
6
|
|
|
6
|
-
from fusion_bench.utils.
|
|
7
|
+
from fusion_bench.utils.dict import dict_merge
|
|
8
|
+
from fusion_bench.utils.type import StateDictType, TorchModelType
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def is_leaf_module(module: nn.Module) -> bool:
    """
    Check whether a module is a leaf, i.e. it has no child modules.

    Args:
        module (nn.Module): PyTorch module to inspect.

    Returns:
        bool: True if `module.children()` yields nothing, False otherwise.
    """
    # Stop at the first child instead of materializing the whole child list.
    return not any(True for _ in module.children())
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def named_leaf_modules(
    module: nn.Module,
    prefix: str = "",
    ignore_empty: bool = True,
) -> Iterable[tuple[str, nn.Module]]:
    """
    Recursively find the leaf modules in a module.

    Args:
        module (nn.Module): PyTorch module.
        prefix (str): A prefix to add to the layer names.
        ignore_empty (bool): If True, leaf modules that own no parameters are skipped.
            Default is True.

    Returns:
        Iterable[tuple[str, nn.Module]]: An iterable of (name, module) tuples for each leaf module.
    """
    for name, submodule in module.named_modules(prefix=prefix):
        if not is_leaf_module(submodule):
            continue
        # Stop at the first parameter rather than building the full parameter list.
        if ignore_empty and not any(True for _ in submodule.parameters()):
            continue
        yield name, submodule
|
|
7
35
|
|
|
8
36
|
|
|
9
37
|
def del_attr(obj, names: List[str]):
|
|
@@ -104,3 +132,163 @@ def disable_dropout(model: torch.nn.Module):
|
|
|
104
132
|
for module in model.modules():
|
|
105
133
|
if isinstance(module, torch.nn.Dropout):
|
|
106
134
|
module.p = 0
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def get_target_state_dict(
    module: nn.Module,
    target_modules: str | Iterable[str] | None = None,
    prefix: str = "",
    keep_vars: bool = False,
) -> StateDictType:
    """
    This function retrieves the state dictionary of specified target submodules within a given module
    of a PyTorch model or merged state dictionary from multiple submodules.

    For example, if a model has submodules named "layer1", "layer2", and "layer3", and you want to get the state dictionary of "layer1" and "layer3",
    you can call this function with `target_modules` set to `["layer1", "layer3"]`.
    The function will return a state dictionary that includes only the parameters and buffers from those specified submodules.

    Args:
        module (nn.Module): The PyTorch module containing the target submodules.
        target_modules (str | Iterable[str]): A single target module name or an iterable of target module names.
            If None, the module's `_fusion_bench_target_modules` attribute is used when it is set and not None;
            otherwise the entire module's state dictionary is returned.
        prefix (str): String prepended verbatim to every returned key. For target submodules the key
            becomes `f"{prefix}{target_module}.<name>"`. Default is "".
        keep_vars (bool): Forwarded to `state_dict()`. Default is False.

    Returns:
        StateDictType: The state dictionary of the specified target submodules, merged if multiple are provided.
    """
    if target_modules is None:
        # Fall back to the targets declared on the module itself, if any.
        target_modules = getattr(module, "_fusion_bench_target_modules", None)
        if target_modules is None:
            return module.state_dict(prefix=prefix, keep_vars=keep_vars)

    if isinstance(target_modules, str):
        target_modules = [target_modules]

    # `prefix` may be empty, in which case the key prefix is just "<target_module>.".
    state_dicts = [
        module.get_submodule(target_module).state_dict(
            prefix=f"{prefix}{target_module}.", keep_vars=keep_vars
        )
        for target_module in target_modules
    ]

    # NOTE(review): `disjoint=True` presumably rejects overlapping keys (e.g. a target
    # nested inside another target) -- confirm in `fusion_bench.utils.dict.dict_merge`.
    return dict_merge(state_dicts, disjoint=True)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def validate_target_modules_equal(modules: Iterable[nn.Module]) -> None:
    """
    Validates that the `_fusion_bench_target_modules` attribute is the same across all provided modules.

    A module that lacks the attribute is treated the same as one whose attribute is None.
    An empty iterable is trivially consistent and passes validation.

    Args:
        modules (Iterable[nn.Module]): An iterable of PyTorch modules to validate.

    Raises:
        ValueError: If the `_fusion_bench_target_modules` attribute differs among the modules.
    """
    attr_name = "_fusion_bench_target_modules"

    module_iter = iter(modules)
    try:
        first_module = next(module_iter)
    except StopIteration:
        # Nothing to compare; do not leak a bare StopIteration to the caller.
        return

    # The first module defines the reference value every other module must match.
    expected = getattr(first_module, attr_name, None)

    for module in module_iter:
        if getattr(module, attr_name, None) != expected:
            raise ValueError(
                "_fusion_bench_target_modules attribute differs among the provided modules."
            )
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def load_state_dict_into_target_modules(
    module: TorchModelType,
    state_dict: StateDictType,
    target_modules: str | Iterable[str] | None = None,
    strict: bool = True,
    assign: bool = False,
):
    """
    Load a state dictionary into specified target submodules within a given module of a PyTorch model.

    This function allows you to load parameters and buffers from a state dictionary into specific submodules
    of a PyTorch model. If the `target_modules` argument is provided, only the specified submodules will be updated
    with the corresponding entries from the state dictionary. Entries whose keys do not start with
    `"<target_module>."` for any target are ignored.

    Args:
        module (nn.Module): The PyTorch module containing the target submodules.
        state_dict (StateDictType): The state dictionary containing parameters and buffers to load.
            Keys are expected to be relative to `module` (e.g. `"layer1.weight"`).
        target_modules (str | Iterable[str]): A single target module name or an iterable of target module names.
            If None, the entire module's state dictionary is updated if no special attribute is set
            (look up the `_fusion_bench_target_modules` attribute).
        strict (bool): Whether to strictly enforce that the keys in `state_dict` match the keys returned by
            the module's `state_dict()` function. Default is True.
        assign (bool): Forwarded unchanged to `load_state_dict()`. Default is False.

    Returns:
        The `_IncompatibleKeys` result of the load. When target submodules are used, the
        per-submodule results are merged and each key is reported relative to `module`
        (i.e. prefixed with `"<target_module>."`).
    """
    if target_modules is None:
        # Fall back to the targets declared on the module itself, if any.
        target_modules = getattr(module, "_fusion_bench_target_modules", None)
        if target_modules is None:
            return module.load_state_dict(state_dict, strict=strict, assign=assign)

    if isinstance(target_modules, str):
        target_modules = [target_modules]
    else:
        # Materialize so that generic iterables (e.g. generators) support len()
        # and can be traversed after the emptiness check.
        target_modules = list(target_modules)

    assert (
        len(target_modules) > 0
    ), "target_modules should contain at least one module name."

    missing_keys: list[str] = []
    unexpected_keys: list[str] = []
    for target_module in target_modules:
        submodule_prefix = f"{target_module}."
        submodule_prefix_len = len(submodule_prefix)
        submodule = module.get_submodule(target_module)

        # Keep only the entries that belong to this submodule. Without the filter,
        # keys of other submodules would be truncated into bogus names and reported
        # as unexpected (raising under strict=True).
        submodule_state_dict = {
            key[submodule_prefix_len:]: value
            for key, value in state_dict.items()
            if key.startswith(submodule_prefix)
        }

        # Load the extracted state dictionary into the submodule
        result = submodule.load_state_dict(
            submodule_state_dict, strict=strict, assign=assign
        )

        # Re-qualify the reported keys so they stay unambiguous after merging.
        missing_keys.extend(f"{submodule_prefix}{key}" for key in result.missing_keys)
        unexpected_keys.extend(
            f"{submodule_prefix}{key}" for key in result.unexpected_keys
        )

    # Merge results from all submodules
    return _IncompatibleKeys(
        missing_keys=missing_keys,
        unexpected_keys=unexpected_keys,
    )
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
"""
|
|
2
|
+
This module contains a wrapper for switching between different models.
|
|
3
|
+
|
|
4
|
+
For example, it can be used to switch between different classification heads for a shared backbone.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Dict, Optional
|
|
9
|
+
|
|
10
|
+
from torch import nn
|
|
11
|
+
|
|
12
|
+
from fusion_bench.utils.misc import first, validate_and_suggest_corrections
|
|
13
|
+
|
|
14
|
+
__all__ = ["SwitchModule", "set_active_option"]
|
|
15
|
+
|
|
16
|
+
log = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _standardize_option_name(name: str) -> str:
|
|
20
|
+
"""
|
|
21
|
+
Standardizes the option name by:
|
|
22
|
+
|
|
23
|
+
- Stripping whitespace and converting to lowercase.
|
|
24
|
+
- Replacing `-` with `_` if needed.
|
|
25
|
+
- Replacing `/` with `_` if needed.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
name (str): The option name to standardize.
|
|
29
|
+
"""
|
|
30
|
+
name = name.strip().lower()
|
|
31
|
+
name = name.replace("-", "_")
|
|
32
|
+
name = name.replace("/", "_")
|
|
33
|
+
return name
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class SwitchModule(nn.Module):
    """
    A wrapper module that contains multiple sub-modules (options) and allows switching between them.

    This is useful for multi-head models or models where different parts are activated based on the task.

    Option names are standardized with `_standardize_option_name` both when the
    module is built and when an option is selected. Calling the module and looking up
    attributes that the wrapper itself does not define are both delegated to the
    currently active option.
    """

    def __init__(self, modules: Dict[str, nn.Module]):
        """
        Args:
            modules (Dict[str, nn.Module]): A dictionary of modules to switch between.

        Raises:
            ValueError: If two option names become identical after standardization,
                which would otherwise silently drop one of the modules.
        """
        super().__init__()
        standardized_modules: Dict[str, nn.Module] = {}
        for name, module in modules.items():
            standardized_name = _standardize_option_name(name)
            if standardized_name in standardized_modules:
                raise ValueError(
                    f"Option name '{name}' collides with another option after "
                    f"standardization to '{standardized_name}'."
                )
            standardized_modules[standardized_name] = module
        self._option_modules = nn.ModuleDict(standardized_modules)
        # The first registered option is active until another one is selected.
        self._active_option = first(self._option_modules.keys())

    def set_active_option(self, option_name: str):
        """
        Select which option subsequent calls and attribute lookups are routed to.

        Args:
            option_name (str): Name of the option; standardized before use.
        """
        standardized_name = _standardize_option_name(option_name)
        # NOTE(review): presumably raises when the name is not a known option -- confirm
        # in `fusion_bench.utils.misc.validate_and_suggest_corrections`.
        validate_and_suggest_corrections(standardized_name, self._option_modules.keys())
        self._active_option = standardized_name

    def forward(self, *args, **kwargs):
        """Call the active option with the given arguments and return its result."""
        active_module = self._option_modules[self._active_option]
        return active_module(*args, **kwargs)

    def __getattr__(self, name):
        """
        Resolve `name` on the wrapper first, then fall back to the active option.

        Raises:
            AttributeError: If neither the wrapper nor the active option has `name`.
        """
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Read the wrapper's own state straight from __dict__: going through
            # normal attribute access would re-enter this method and recurse without
            # bound when `_option_modules` or `_active_option` is not set yet.
            options = self.__dict__.get("_modules", {}).get("_option_modules")
            active_option = self.__dict__.get("_active_option")
            if options is None or active_option is None:
                raise
            active_module = options[active_option]
            missing = object()
            value = getattr(active_module, name, missing)
            if value is not missing:
                return value
            raise
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def set_active_option(module: nn.Module, option_name: str) -> list[str]:
    """
    Activate `option_name` on every `SwitchModule` found inside `module`.

    The traversal uses `module.named_modules()`, and each `SwitchModule` encountered
    has its `set_active_option` method called with `option_name`.

    Args:
        module (nn.Module): The module to set the active option for.
        option_name (str): The name of the option to activate.

    Returns:
        list[str]: A list of names of submodules that were activated.
    """
    switched: list[str] = []
    for qualified_name, candidate in module.named_modules():
        if not isinstance(candidate, SwitchModule):
            continue
        candidate.set_active_option(option_name)
        switched.append(qualified_name)
    return switched
|
|
@@ -75,6 +75,12 @@ class BaseHydraProgram(BaseYAMLSerializable):
|
|
|
75
75
|
- FusionBench CLI documentation for program execution details
|
|
76
76
|
"""
|
|
77
77
|
|
|
78
|
+
_program = None
|
|
79
|
+
|
|
80
|
+
def __init__(self, **kwargs):
|
|
81
|
+
super().__init__(**kwargs)
|
|
82
|
+
self._program = self
|
|
83
|
+
|
|
78
84
|
@abstractmethod
|
|
79
85
|
def run(self):
|
|
80
86
|
"""
|
|
@@ -267,6 +267,7 @@ class FabricModelFusionProgram(
|
|
|
267
267
|
merged_model = self.method.run(self.modelpool)
|
|
268
268
|
self.method.on_run_end()
|
|
269
269
|
|
|
270
|
+
report = None
|
|
270
271
|
if merged_model is None:
|
|
271
272
|
log.info(
|
|
272
273
|
"No merged model returned by the method. Skipping saving and evaluation."
|
|
@@ -293,5 +294,8 @@ class FabricModelFusionProgram(
|
|
|
293
294
|
)
|
|
294
295
|
os.makedirs(os.path.dirname(self.report_save_path), exist_ok=True)
|
|
295
296
|
json.dump(report, open(self.report_save_path, "w"))
|
|
297
|
+
self.log_artifact(local_path=self.report_save_path)
|
|
296
298
|
else:
|
|
297
299
|
log.info("No task pool specified. Skipping evaluation.")
|
|
300
|
+
|
|
301
|
+
return {"merged_model": merged_model, "report": report}
|
fusion_bench/py.typed
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
fusion_bench/scripts/cli.py
CHANGED
|
@@ -3,36 +3,21 @@
|
|
|
3
3
|
This is the CLI script that is executed when the user runs the `fusion_bench` command.
|
|
4
4
|
The script is responsible for parsing the command-line arguments, loading the configuration file, and running the fusion algorithm.
|
|
5
5
|
"""
|
|
6
|
-
|
|
7
|
-
import importlib
|
|
8
|
-
import importlib.resources
|
|
9
6
|
import logging
|
|
10
|
-
import
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
11
8
|
|
|
12
9
|
import hydra
|
|
13
10
|
from omegaconf import DictConfig, OmegaConf
|
|
14
11
|
|
|
15
|
-
from fusion_bench.constants import PROJECT_ROOT_PATH
|
|
16
|
-
from fusion_bench.programs import BaseHydraProgram
|
|
17
12
|
from fusion_bench.utils import instantiate
|
|
13
|
+
from fusion_bench.utils.hydra_utils import get_default_config_path
|
|
18
14
|
|
|
19
|
-
|
|
20
|
-
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from fusion_bench.programs import BaseHydraProgram
|
|
21
17
|
|
|
22
|
-
|
|
23
|
-
for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
|
|
24
|
-
for config_dir in ["config", "fusion_bench_config"]:
|
|
25
|
-
config_path = os.path.join(config_path_root, config_dir)
|
|
26
|
-
if os.path.exists(config_path) and os.path.isdir(config_path):
|
|
27
|
-
return os.path.abspath(config_path)
|
|
28
|
-
return None
|
|
18
|
+
log = logging.getLogger(__name__)
|
|
29
19
|
|
|
30
20
|
|
|
31
|
-
@hydra.main(
|
|
32
|
-
config_path=_get_default_config_path(),
|
|
33
|
-
config_name="fabric_model_fusion",
|
|
34
|
-
version_base=None,
|
|
35
|
-
)
|
|
36
21
|
def main(cfg: DictConfig) -> None:
|
|
37
22
|
"""
|
|
38
23
|
Main entry point for the FusionBench command-line interface.
|
|
@@ -68,7 +53,7 @@ def main(cfg: DictConfig) -> None:
|
|
|
68
53
|
loading the corresponding configuration files to populate the cfg parameter.
|
|
69
54
|
"""
|
|
70
55
|
OmegaConf.resolve(cfg)
|
|
71
|
-
program: BaseHydraProgram = instantiate(cfg)
|
|
56
|
+
program: "BaseHydraProgram" = instantiate(cfg)
|
|
72
57
|
|
|
73
58
|
# Validate that instantiation succeeded and returned an object with 'run' method
|
|
74
59
|
if not hasattr(program, "run") or not callable(getattr(program, "run")):
|
|
@@ -83,8 +68,25 @@ def main(cfg: DictConfig) -> None:
|
|
|
83
68
|
err_msg += f"\n\nConfiguration content:\n{cfg}"
|
|
84
69
|
raise TypeError(err_msg)
|
|
85
70
|
|
|
86
|
-
|
|
71
|
+
try:
|
|
72
|
+
program_result = program.run()
|
|
73
|
+
return program_result
|
|
74
|
+
except BaseException as e:
|
|
75
|
+
# Log the exception before exiting
|
|
76
|
+
if hasattr(program, "finalize") and callable(getattr(program, "finalize")):
|
|
77
|
+
program.finalize()
|
|
78
|
+
log.error(e, exc_info=True)
|
|
79
|
+
raise e
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
@hydra.main(
    config_path=get_default_config_path(),
    config_name="fabric_model_fusion",
    version_base=None,
)
def _hydra_main(cfg: DictConfig) -> None:
    """
    Hydra-decorated entry point that forwards the composed configuration to `main`.

    Keeping the decorator on this thin wrapper leaves `main` undecorated, so `main`
    can be called directly with an existing `DictConfig`. The value returned by
    `main` is not propagated.

    Args:
        cfg (DictConfig): Configuration passed in by Hydra.
    """
    main(cfg)
|
|
87
89
|
|
|
88
90
|
|
|
89
91
|
if __name__ == "__main__":
|
|
90
|
-
|
|
92
|
+
_hydra_main()
|
fusion_bench/scripts/imgui.py
CHANGED
|
@@ -9,7 +9,7 @@ import hydra
|
|
|
9
9
|
from hydra import compose, initialize_config_dir
|
|
10
10
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
11
11
|
|
|
12
|
-
from fusion_bench.scripts.cli import
|
|
12
|
+
from fusion_bench.scripts.cli import get_default_config_path
|
|
13
13
|
|
|
14
14
|
# Keeping the ConfigGroupNode and AppState classes as they are
|
|
15
15
|
from fusion_bench.scripts.webui import AppState, ConfigGroupNode, priority_iterable
|
|
@@ -40,7 +40,7 @@ class App:
|
|
|
40
40
|
if self.args.config_path:
|
|
41
41
|
return Path(self.args.config_path)
|
|
42
42
|
else:
|
|
43
|
-
return
|
|
43
|
+
return get_default_config_path()
|
|
44
44
|
|
|
45
45
|
def generate_ui(self):
|
|
46
46
|
dpg.create_context()
|
fusion_bench/scripts/webui.py
CHANGED
|
@@ -16,7 +16,7 @@ from hydra import compose, initialize_config_dir
|
|
|
16
16
|
from hydra.core.hydra_config import HydraConfig
|
|
17
17
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
18
18
|
|
|
19
|
-
from fusion_bench.scripts.cli import
|
|
19
|
+
from fusion_bench.scripts.cli import get_default_config_path
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
def escape_overrides(value: str) -> str:
|
|
@@ -385,7 +385,7 @@ class App:
|
|
|
385
385
|
if self.args.config_path:
|
|
386
386
|
return Path(self.args.config_path)
|
|
387
387
|
else:
|
|
388
|
-
return
|
|
388
|
+
return get_default_config_path()
|
|
389
389
|
|
|
390
390
|
def __getattr__(self, name):
|
|
391
391
|
"""
|