fusion-bench 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/constants/runtime.py +4 -1
- fusion_bench/method/__init__.py +9 -1
- fusion_bench/method/base_algorithm.py +29 -19
- fusion_bench/method/classification/image_classification_finetune.py +1 -0
- fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
- fusion_bench/method/task_singular_vector/TSVM.py +7 -6
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
- fusion_bench/metrics/model_kinship/__init__.py +2 -0
- fusion_bench/metrics/model_kinship/calculate.py +77 -0
- fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
- fusion_bench/metrics/model_kinship/utility.py +184 -0
- fusion_bench/mixins/lightning_fabric.py +2 -8
- fusion_bench/mixins/openclip_classification.py +155 -1
- fusion_bench/modelpool/base_pool.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
- fusion_bench/models/masks/mask_model.py +8 -2
- fusion_bench/models/open_clip/modeling.py +68 -5
- fusion_bench/models/open_clip/utils.py +13 -2
- fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
- fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
- fusion_bench/py.typed +1 -0
- fusion_bench/scripts/cli.py +21 -16
- fusion_bench/scripts/imgui.py +2 -2
- fusion_bench/scripts/webui.py +2 -2
- fusion_bench/utils/__init__.py +2 -0
- fusion_bench/utils/devices.py +3 -1
- fusion_bench/utils/hydra_utils.py +75 -0
- fusion_bench/utils/instantiate_utils.py +29 -18
- fusion_bench/utils/misc.py +16 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/rich_utils.py +165 -25
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/METADATA +7 -7
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/RECORD +41 -34
- fusion_bench_config/README.md +9 -0
- fusion_bench_config/fabric/auto.yaml +1 -0
- fusion_bench_config/hydra/default.yaml +3 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.29.dist-info → fusion_bench-0.2.31.dist-info}/top_level.txt +0 -0
|
@@ -77,7 +77,16 @@ def torch_load_old(save_path: str, device=None):
|
|
|
77
77
|
return classifier
|
|
78
78
|
|
|
79
79
|
|
|
80
|
-
def torch_save(model, save_path, save_state_dict=True):
|
|
80
|
+
def torch_save(model: torch.nn.Module, save_path: str, save_state_dict: bool = True):
|
|
81
|
+
"""
|
|
82
|
+
Save a model to disk.
|
|
83
|
+
|
|
84
|
+
Args:
|
|
85
|
+
model: The model to save.
|
|
86
|
+
save_path (str): The path to save the model to.
|
|
87
|
+
save_state_dict (bool): Whether to save the state dict of the model (weights only).
|
|
88
|
+
If False, the entire model object is saved. Default is True.
|
|
89
|
+
"""
|
|
81
90
|
# TODO: hacky way to save state dict
|
|
82
91
|
if save_state_dict and isinstance(model, torch.nn.Module):
|
|
83
92
|
model = model.state_dict()
|
|
@@ -86,7 +95,9 @@ def torch_save(model, save_path, save_state_dict=True):
|
|
|
86
95
|
torch.save(model, save_path)
|
|
87
96
|
|
|
88
97
|
|
|
89
|
-
def torch_load(
|
|
98
|
+
def torch_load(
|
|
99
|
+
save_path: str, device: Optional[torch.device] = None
|
|
100
|
+
) -> torch.nn.Module:
|
|
90
101
|
model = torch.load(save_path, map_location="cpu")
|
|
91
102
|
if device is not None:
|
|
92
103
|
model = model.to(device)
|
|
@@ -173,6 +173,24 @@ class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
|
|
|
173
173
|
|
|
174
174
|
@property
|
|
175
175
|
def forward_model(self):
|
|
176
|
+
"""
|
|
177
|
+
Get a functional model with merged parameters.
|
|
178
|
+
|
|
179
|
+
Returns a partial function that applies the pretrained model with the current
|
|
180
|
+
merged state dictionary. This allows for efficient forward passes without
|
|
181
|
+
modifying the original model's parameters.
|
|
182
|
+
|
|
183
|
+
Returns:
|
|
184
|
+
Callable: A partial function that can be called with (args, kwargs) to
|
|
185
|
+
perform forward pass with merged parameters.
|
|
186
|
+
|
|
187
|
+
Example:
|
|
188
|
+
```python
|
|
189
|
+
# Internal usage during forward pass
|
|
190
|
+
forward_fn = merged_model.forward_model
|
|
191
|
+
output = forward_fn(args=(x,), kwargs={})
|
|
192
|
+
```
|
|
193
|
+
"""
|
|
176
194
|
return functools.partial(
|
|
177
195
|
functional_call,
|
|
178
196
|
self.pretrained_model,
|
|
@@ -181,10 +199,30 @@ class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
|
|
|
181
199
|
strict=self.strict,
|
|
182
200
|
)
|
|
183
201
|
|
|
184
|
-
def merge_and_unload(
|
|
202
|
+
def merge_and_unload(
    self,
    task_vector_mask: Optional[Dict[str, Tensor]] = None,
    copy: bool = False,
) -> TorchModelType:
    """
    Run the merge and hand back a model carrying the merged parameters.

    `merge_weights` is invoked first, then `self._merged_state_dict` is loaded
    into either the wrapped pretrained model or a deep copy of it.

    Args:
        task_vector_mask (Optional[Dict[str, Tensor]], optional): Masks forwarded
            unchanged to `merge_weights`. Defaults to None.
        copy (bool, optional): When True, the merged state dict is loaded into a
            deep copy, so `self.pretrained_model` is not touched by the load.
            When False (default), `self.pretrained_model` itself receives the
            merged state dict and is the object returned.

    Returns:
        TorchModelType: The model into which the merged state dict was loaded.
    """
    # The merge runs before any copy is taken.
    # NOTE(review): presumably this call refreshes `self._merged_state_dict` - confirm.
    self.merge_weights(task_vector_mask=task_vector_mask)

    target = deepcopy(self.pretrained_model) if copy else self.pretrained_model
    target.load_state_dict(self._merged_state_dict)
    return target
|
|
188
226
|
|
|
189
227
|
def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
|
|
190
228
|
"""
|
|
@@ -16,6 +16,7 @@ outputs = merged_model(inputs)
|
|
|
16
16
|
|
|
17
17
|
import functools
|
|
18
18
|
import logging
|
|
19
|
+
from copy import deepcopy
|
|
19
20
|
from typing import Any, Callable, Dict, Generic, Iterator, List, Optional # noqa: F401
|
|
20
21
|
|
|
21
22
|
import torch
|
|
@@ -327,7 +328,11 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
|
|
|
327
328
|
self._merged_state_dict = state_dict
|
|
328
329
|
return state_dict
|
|
329
330
|
|
|
330
|
-
def merge_and_unload(
|
|
331
|
+
def merge_and_unload(
|
|
332
|
+
self,
|
|
333
|
+
task_vector_mask: Optional[Dict[str, Tensor]] = None,
|
|
334
|
+
copy: bool = False,
|
|
335
|
+
) -> TorchModelType:
|
|
331
336
|
"""
|
|
332
337
|
Merge models and return the final merged model.
|
|
333
338
|
|
|
@@ -338,6 +343,8 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
|
|
|
338
343
|
Args:
|
|
339
344
|
task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
|
|
340
345
|
for selective parameter merging. Defaults to None.
|
|
346
|
+
copy (bool, optional): Whether to return a deep copy of the pretrained model.
|
|
347
|
+
Defaults to False. If True, the original pretrained model remains unchanged.
|
|
341
348
|
|
|
342
349
|
Returns:
|
|
343
350
|
TorchModelType: The pretrained model with merged parameters loaded.
|
|
@@ -363,8 +370,12 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
|
|
|
363
370
|
The original pretrained model parameters will be lost.
|
|
364
371
|
"""
|
|
365
372
|
self.merge_weights(task_vector_mask=task_vector_mask)
|
|
366
|
-
|
|
367
|
-
|
|
373
|
+
if copy:
|
|
374
|
+
model = deepcopy(self.pretrained_model)
|
|
375
|
+
else:
|
|
376
|
+
model = self.pretrained_model
|
|
377
|
+
model.load_state_dict(self._merged_state_dict)
|
|
378
|
+
return model
|
|
368
379
|
|
|
369
380
|
def forward(self, *args, **kwargs):
|
|
370
381
|
"""
|
fusion_bench/py.typed
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
fusion_bench/scripts/cli.py
CHANGED
|
@@ -3,33 +3,24 @@
|
|
|
3
3
|
This is the CLI script that is executed when the user runs the `fusion_bench` command.
|
|
4
4
|
The script is responsible for parsing the command-line arguments, loading the configuration file, and running the fusion algorithm.
|
|
5
5
|
"""
|
|
6
|
-
|
|
7
|
-
import importlib
|
|
8
|
-
import importlib.resources
|
|
9
6
|
import logging
|
|
10
|
-
import
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
11
8
|
|
|
12
9
|
import hydra
|
|
13
10
|
from omegaconf import DictConfig, OmegaConf
|
|
14
11
|
|
|
15
12
|
from fusion_bench.constants import PROJECT_ROOT_PATH
|
|
16
|
-
from fusion_bench.programs import BaseHydraProgram
|
|
17
13
|
from fusion_bench.utils import instantiate
|
|
14
|
+
from fusion_bench.utils.hydra_utils import get_default_config_path
|
|
18
15
|
|
|
19
|
-
|
|
20
|
-
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from fusion_bench.programs import BaseHydraProgram
|
|
21
18
|
|
|
22
|
-
|
|
23
|
-
for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
|
|
24
|
-
for config_dir in ["config", "fusion_bench_config"]:
|
|
25
|
-
config_path = os.path.join(config_path_root, config_dir)
|
|
26
|
-
if os.path.exists(config_path) and os.path.isdir(config_path):
|
|
27
|
-
return os.path.abspath(config_path)
|
|
28
|
-
return None
|
|
19
|
+
log = logging.getLogger(__name__)
|
|
29
20
|
|
|
30
21
|
|
|
31
22
|
@hydra.main(
|
|
32
|
-
config_path=
|
|
23
|
+
config_path=get_default_config_path(),
|
|
33
24
|
config_name="fabric_model_fusion",
|
|
34
25
|
version_base=None,
|
|
35
26
|
)
|
|
@@ -68,7 +59,21 @@ def main(cfg: DictConfig) -> None:
|
|
|
68
59
|
loading the corresponding configuration files to populate the cfg parameter.
|
|
69
60
|
"""
|
|
70
61
|
OmegaConf.resolve(cfg)
|
|
71
|
-
program: BaseHydraProgram = instantiate(cfg)
|
|
62
|
+
program: "BaseHydraProgram" = instantiate(cfg)
|
|
63
|
+
|
|
64
|
+
# Validate that instantiation succeeded and returned an object with 'run' method
|
|
65
|
+
if not hasattr(program, "run") or not callable(getattr(program, "run")):
|
|
66
|
+
err_msg = (
|
|
67
|
+
f"Expected an object with a callable 'run' method, but got {type(program).__name__}. "
|
|
68
|
+
"Ensure that the configuration specifies a concrete program class with '_target_'."
|
|
69
|
+
)
|
|
70
|
+
if "_target_" not in cfg:
|
|
71
|
+
err_msg += "\nThe '_target_' field is missing from the root configuration."
|
|
72
|
+
else:
|
|
73
|
+
err_msg += f"\nFound '_target_': {cfg._target_}"
|
|
74
|
+
err_msg += f"\n\nConfiguration content:\n{cfg}"
|
|
75
|
+
raise TypeError(err_msg)
|
|
76
|
+
|
|
72
77
|
program.run()
|
|
73
78
|
|
|
74
79
|
|
fusion_bench/scripts/imgui.py
CHANGED
|
@@ -9,7 +9,7 @@ import hydra
|
|
|
9
9
|
from hydra import compose, initialize_config_dir
|
|
10
10
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
11
11
|
|
|
12
|
-
from fusion_bench.scripts.cli import
|
|
12
|
+
from fusion_bench.scripts.cli import get_default_config_path
|
|
13
13
|
|
|
14
14
|
# Keeping the ConfigGroupNode and AppState classes as they are
|
|
15
15
|
from fusion_bench.scripts.webui import AppState, ConfigGroupNode, priority_iterable
|
|
@@ -40,7 +40,7 @@ class App:
|
|
|
40
40
|
if self.args.config_path:
|
|
41
41
|
return Path(self.args.config_path)
|
|
42
42
|
else:
|
|
43
|
-
return
|
|
43
|
+
return get_default_config_path()
|
|
44
44
|
|
|
45
45
|
def generate_ui(self):
|
|
46
46
|
dpg.create_context()
|
fusion_bench/scripts/webui.py
CHANGED
|
@@ -16,7 +16,7 @@ from hydra import compose, initialize_config_dir
|
|
|
16
16
|
from hydra.core.hydra_config import HydraConfig
|
|
17
17
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
|
18
18
|
|
|
19
|
-
from fusion_bench.scripts.cli import
|
|
19
|
+
from fusion_bench.scripts.cli import get_default_config_path
|
|
20
20
|
|
|
21
21
|
|
|
22
22
|
def escape_overrides(value: str) -> str:
|
|
@@ -385,7 +385,7 @@ class App:
|
|
|
385
385
|
if self.args.config_path:
|
|
386
386
|
return Path(self.args.config_path)
|
|
387
387
|
else:
|
|
388
|
-
return
|
|
388
|
+
return get_default_config_path()
|
|
389
389
|
|
|
390
390
|
def __getattr__(self, name):
|
|
391
391
|
"""
|
fusion_bench/utils/__init__.py
CHANGED
|
@@ -53,6 +53,7 @@ _import_structure = {
|
|
|
53
53
|
"get_parameter_summary",
|
|
54
54
|
"human_readable",
|
|
55
55
|
"print_parameters",
|
|
56
|
+
"print_trainable_parameters",
|
|
56
57
|
"state_dict_to_vector",
|
|
57
58
|
"trainable_state_dict",
|
|
58
59
|
"vector_to_state_dict",
|
|
@@ -138,6 +139,7 @@ if TYPE_CHECKING:
|
|
|
138
139
|
get_parameter_summary,
|
|
139
140
|
human_readable,
|
|
140
141
|
print_parameters,
|
|
142
|
+
print_trainable_parameters,
|
|
141
143
|
state_dict_to_vector,
|
|
142
144
|
trainable_state_dict,
|
|
143
145
|
vector_to_state_dict,
|
fusion_bench/utils/devices.py
CHANGED
|
@@ -32,11 +32,13 @@ def clear_cuda_cache():
|
|
|
32
32
|
Clears the CUDA memory cache to free up GPU memory.
|
|
33
33
|
Works only if CUDA is available.
|
|
34
34
|
"""
|
|
35
|
+
|
|
35
36
|
gc.collect()
|
|
36
37
|
if torch.cuda.is_available():
|
|
37
38
|
torch.cuda.empty_cache()
|
|
39
|
+
gc.collect()
|
|
38
40
|
else:
|
|
39
|
-
log.
|
|
41
|
+
log.debug("CUDA is not available. No cache to clear.")
|
|
40
42
|
|
|
41
43
|
|
|
42
44
|
def to_device(
|
|
@@ -1,4 +1,79 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
|
|
1
4
|
import hydra.core.hydra_config
|
|
5
|
+
from hydra import compose, initialize
|
|
6
|
+
from omegaconf import DictConfig
|
|
7
|
+
|
|
8
|
+
from fusion_bench.constants import PROJECT_ROOT_PATH
|
|
9
|
+
|
|
10
|
+
log = logging.getLogger(__name__)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_default_config_path():
    """
    Locate the default configuration directory.

    The current working directory is searched first, then `PROJECT_ROOT_PATH`;
    within each root, `config` is tried before `fusion_bench_config`.

    Returns:
        The absolute path of the first candidate that is an existing directory,
        or None when no candidate qualifies.
    """
    search_roots = (os.getcwd(), PROJECT_ROOT_PATH)
    directory_names = ("config", "fusion_bench_config")

    candidates = (
        os.path.join(root, name) for root in search_roots for name in directory_names
    )
    for candidate in candidates:
        # isdir() is already False for paths that do not exist.
        if os.path.isdir(candidate):
            return os.path.abspath(candidate)
    return None
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def initialize_hydra_config(
    config_name: str,
    overrides: list[str] = None,
    config_path: str = None,
    return_hydra_config: bool = False,
) -> DictConfig:
    """
    Compose a Hydra configuration from a configuration directory.

    Args:
        config_name (str): The name of the configuration file (without .yaml extension).
        overrides (list[str]): Configuration overrides; None is treated as an empty list.
        config_path (str): The configuration directory. If None, the result of
            `get_default_config_path()` is used.
        return_hydra_config (bool): Forwarded to `compose`; if True, the Hydra
            configuration is included in the result.

    Returns:
        DictConfig: The composed configuration.

    Raises:
        FileNotFoundError: If no configuration directory was given and none
            could be detected.
        NotADirectoryError: If the configuration path is not an existing directory.

    Example:
        >>> cfg = initialize_hydra_config(
        ...     config_name="fabric_model_fusion",
        ...     overrides=["method=dummy", "modelpool=dummy"],
        ... )
        >>> print(cfg.method)
    """
    config_dir = get_default_config_path() if config_path is None else config_path

    # Validate the directory before handing anything to Hydra.
    if config_dir is None:
        raise FileNotFoundError("Could not find configuration directory.")
    if not os.path.isdir(config_dir):
        raise NotADirectoryError(
            f"Configuration path {config_dir} do not exists or is not a directory."
        )

    # `initialize` receives the directory expressed relative to this module's folder.
    module_dir = os.path.dirname(__file__)
    relative_config_dir = os.path.relpath(config_dir, start=module_dir)
    effective_overrides = [] if overrides is None else overrides

    with initialize(version_base=None, config_path=relative_config_dir):
        return compose(
            config_name=config_name,
            overrides=effective_overrides,
            return_hydra_config=return_hydra_config,
        )
|
|
2
77
|
|
|
3
78
|
|
|
4
79
|
def get_hydra_output_dir():
|
|
@@ -14,8 +14,8 @@ from lightning_utilities.core.rank_zero import rank_zero_only
|
|
|
14
14
|
from omegaconf import DictConfig, OmegaConf, SCMode
|
|
15
15
|
from omegaconf._utils import is_structured_config
|
|
16
16
|
from rich import print
|
|
17
|
-
|
|
18
|
-
from
|
|
17
|
+
|
|
18
|
+
from fusion_bench.utils.rich_utils import print_bordered
|
|
19
19
|
|
|
20
20
|
PRINT_FUNCTION_CALL = True
|
|
21
21
|
"""
|
|
@@ -67,12 +67,22 @@ def _resolve_callable_name(f: Callable[..., Any]) -> str:
|
|
|
67
67
|
return full_name
|
|
68
68
|
|
|
69
69
|
|
|
70
|
-
def
|
|
70
|
+
def _get_obj_str(obj: Any) -> str:
|
|
71
|
+
if isinstance(obj, (str, int, float, bool, type(None))):
|
|
72
|
+
return repr(obj)
|
|
73
|
+
else:
|
|
74
|
+
return f"'<{type(obj).__name__} object>'"
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _format_args_kwargs(args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> str:
|
|
71
78
|
result_strings = []
|
|
72
79
|
if len(args) > 0:
|
|
73
|
-
result_strings.append(", ".join(
|
|
80
|
+
result_strings.append(", ".join(_get_obj_str(arg) for arg in args))
|
|
81
|
+
|
|
74
82
|
if len(kwargs) > 0:
|
|
75
|
-
result_strings.append(
|
|
83
|
+
result_strings.append(
|
|
84
|
+
", ".join(f"{k}={_get_obj_str(v)}" for k, v in kwargs.items())
|
|
85
|
+
)
|
|
76
86
|
|
|
77
87
|
if len(result_strings) == 0:
|
|
78
88
|
return ""
|
|
@@ -145,14 +155,14 @@ def _call_target(
|
|
|
145
155
|
if _partial_:
|
|
146
156
|
if PRINT_FUNCTION_CALL and getattr(rank_zero_only, "rank", 0) == 0:
|
|
147
157
|
call_str = f"functools.partial({_resolve_callable_name(_target_)}, {_format_args_kwargs(args, kwargs)})"
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
158
|
+
print_bordered(
|
|
159
|
+
call_str,
|
|
160
|
+
code_style="python",
|
|
161
|
+
title=f"Instantiate by calling {'function' if not isinstance(_target_, type) else 'class'}",
|
|
162
|
+
style="cyan",
|
|
163
|
+
expand=False,
|
|
164
|
+
print_fn=PRINT_FUNCTION_CALL_FUNC,
|
|
154
165
|
)
|
|
155
|
-
|
|
156
166
|
if CATCH_EXCEPTION:
|
|
157
167
|
try:
|
|
158
168
|
return functools.partial(_target_, *args, **kwargs)
|
|
@@ -169,12 +179,13 @@ def _call_target(
|
|
|
169
179
|
else:
|
|
170
180
|
if PRINT_FUNCTION_CALL and getattr(rank_zero_only, "rank", 0) == 0:
|
|
171
181
|
call_str = f"{_resolve_callable_name(_target_)}({_format_args_kwargs(args, kwargs)})"
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
182
|
+
print_bordered(
|
|
183
|
+
call_str,
|
|
184
|
+
code_style="python",
|
|
185
|
+
title=f"Instantiate by calling {'function' if not isinstance(_target_, type) else 'class'}",
|
|
186
|
+
style="green",
|
|
187
|
+
expand=False,
|
|
188
|
+
print_fn=PRINT_FUNCTION_CALL_FUNC,
|
|
178
189
|
)
|
|
179
190
|
if CATCH_EXCEPTION:
|
|
180
191
|
try:
|
fusion_bench/utils/misc.py
CHANGED
|
@@ -178,3 +178,19 @@ def validate_and_suggest_corrections(
|
|
|
178
178
|
if matches:
|
|
179
179
|
msg += f". Did you mean {', '.join(repr(m) for m in matches)}?"
|
|
180
180
|
raise ValueError(msg)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
class DeprecationWarningMeta(type):
    """
    Metaclass for deprecated classes.

    Every instantiation of a class that uses this metaclass emits a
    `DeprecationWarning` naming the class, then proceeds with normal construction.
    """

    def __call__(cls, *args, **kwargs):
        import warnings

        message = (
            f"{cls.__name__} is deprecated and will be removed in a future version. "
        )
        # stacklevel=2 attributes the warning to the code instantiating the class.
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return super().__call__(*args, **kwargs)
|
fusion_bench/utils/parameters.py
CHANGED
|
@@ -10,6 +10,7 @@ from .type import StateDictType
|
|
|
10
10
|
__all__ = [
|
|
11
11
|
"count_parameters",
|
|
12
12
|
"print_parameters",
|
|
13
|
+
"print_trainable_parameters",
|
|
13
14
|
"check_parameters_all_equal",
|
|
14
15
|
"get_parameter_statistics",
|
|
15
16
|
"state_dict_to_vector",
|
|
@@ -282,6 +283,38 @@ def print_parameters(
|
|
|
282
283
|
)
|
|
283
284
|
|
|
284
285
|
|
|
286
|
+
def print_trainable_parameters(
    module: nn.Module,
    is_human_readable: bool = True,
    print_fn=print,
    non_zero_only: bool = False,
):
    """
    Report every trainable parameter of a PyTorch model, one line per parameter.

    Only parameters with `requires_grad` set are listed. Each line has the form
    `"<name>: <count> parameters"`, where the count comes from `_numel`.

    Args:
        module (nn.Module): The PyTorch model to inspect.
        is_human_readable (bool, optional): If True, the count is passed through
            `human_readable` before printing. Defaults to True.
        print_fn (callable, optional): Called once per reported line. Defaults to print.
        non_zero_only (bool, optional): Forwarded to `_numel`; whether to count
            only non-zero parameters. Defaults to False.

    ```python
    print_trainable_parameters(model)
    # weight: 1.50M parameters
    # bias: 500.00K parameters
    ```
    """
    for param_name, parameter in module.named_parameters():
        if not parameter.requires_grad:
            continue  # frozen parameters are not reported
        count = _numel(parameter, non_zero_only=non_zero_only)
        shown = human_readable(count) if is_human_readable else count
        print_fn(f"{param_name}: {shown} parameters")
|
|
316
|
+
|
|
317
|
+
|
|
285
318
|
def check_parameters_all_equal(
|
|
286
319
|
list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]],
|
|
287
320
|
) -> None:
|