fusion-bench 0.2.31__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/__main__.py +2 -2
- fusion_bench/dataset/__init__.py +2 -0
- fusion_bench/dataset/clip_dataset.py +4 -72
- fusion_bench/dataset/image_dataset.py +44 -18
- fusion_bench/method/base_algorithm.py +4 -0
- fusion_bench/method/dop/dop.py +0 -22
- fusion_bench/method/dop/dop_general.py +489 -0
- fusion_bench/method/dop/utils.py +24 -4
- fusion_bench/method/emr_merging/__init__.py +1 -0
- fusion_bench/method/emr_merging/emr_merging.py +53 -0
- fusion_bench/method/emr_merging/utils.py +162 -0
- fusion_bench/method/opcm/opcm.py +6 -2
- fusion_bench/method/opcm/opcm_general.py +356 -0
- fusion_bench/method/opcm/utils.py +1 -4
- fusion_bench/method/simple_average.py +52 -18
- fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +108 -3
- fusion_bench/mixins/serialization.py +1 -1
- fusion_bench/modelpool/base_pool.py +37 -1
- fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
- fusion_bench/models/hf_clip.py +20 -0
- fusion_bench/models/modulator/__init__.py +1 -0
- fusion_bench/models/modulator/base.py +123 -0
- fusion_bench/models/parameter_dict.py +119 -29
- fusion_bench/models/utils.py +190 -2
- fusion_bench/models/wrappers/switch.py +90 -0
- fusion_bench/programs/base_program.py +6 -0
- fusion_bench/programs/fabric_fusion_program.py +4 -0
- fusion_bench/scripts/cli.py +19 -8
- fusion_bench/taskpool/image_classification.py +270 -0
- fusion_bench/utils/__init__.py +18 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/dict.py +19 -0
- fusion_bench/utils/dtype.py +19 -0
- fusion_bench/utils/misc.py +1 -0
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/state_dict_arithmetic.py +183 -1
- fusion_bench/utils/tensorboard.py +21 -3
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +51 -37
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
- fusion_bench_config/method/dop/dop_general.yaml +33 -0
- fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
- fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
- fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/models/parameter_dict.py
CHANGED

@@ -1,12 +1,12 @@
-from typing import List, Mapping, Optional, Tuple
+from typing import Iterator, List, Mapping, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
-__all__ = "
+__all__ = ["ParameterDictModel"]
 
 
-def _set_attr(
+def set_nested_attr(
     obj,
     names: List[str],
     val,
@@ -27,7 +27,7 @@ def _set_attr(
     else:
         if check_parent and not hasattr(obj, names[0]):
             setattr(obj, names[0], parent_builder())
-
+        set_nested_attr(
             getattr(obj, names[0]),
             names[1:],
             val,
@@ -36,7 +36,7 @@ def _set_attr(
         )
 
 
-def has_attr(obj, names: List[str]):
+def has_nested_attr(obj, names: List[str]):
     """
     Checks if an attribute exists in an object recursively.
 
@@ -50,26 +50,49 @@ def has_attr(obj, names: List[str]):
     if len(names) == 1:
        return hasattr(obj, names[0])
     else:
-
+        if not hasattr(obj, names[0]):
+            return False
+        return has_nested_attr(getattr(obj, names[0]), names[1:])
 
 
 class ParameterDictModel(nn.Module):
     """
-
-
+    A module that stores parameters in a nested dictionary structure.
+
+    This model behaves similarly to `nn.ParameterDict`, but supports hierarchical keys
+    with dots (e.g., "layer1.weight"). Parameters are stored as nested attributes,
+    allowing for structured parameter access and manipulation.
+
+    Example:
+        >>> params = {
+        ...     "encoder.weight": nn.Parameter(torch.randn(10, 5)),
+        ...     "decoder.bias": nn.Parameter(torch.randn(5)),
+        ... }
+        >>> model = ParameterDictModel(params)
+        >>> model["encoder.weight"].shape
+        torch.Size([10, 5])
+        >>> "encoder.weight" in model
+        True
     """
 
     def __init__(
         self,
-        parameters: Optional[Mapping[str, nn.Parameter]] = None,
-    ):
+        parameters: Optional[Mapping[str, Union[nn.Parameter, torch.Tensor]]] = None,
+    ) -> None:
+        """
+        Args:
+            parameters: Optional mapping of parameter names to parameter tensors.
+                Keys can contain dots to create nested structures.
+                Values must be `nn.Parameter` or `nn.Buffer` instances.
+        """
+
         super().__init__()
         if parameters is not None:
             for name, param in parameters.items():
                 assert isinstance(
                     param, (nn.Parameter, nn.Buffer)
                 ), f"{name} is not a nn.Parameter or nn.Buffer"
-
+                set_nested_attr(
                     self,
                     name.split("."),
                     param,
@@ -77,12 +100,13 @@ class ParameterDictModel(nn.Module):
                     parent_builder=__class__,
                 )
 
-    def __repr__(self):
+    def __repr__(self) -> str:
        """
         Generate a string representation of the model's parameters.
 
         Returns:
-
+            A string representation of the model's parameters in the format:
+            "ParameterDictModel(name1: shape1, name2: shape2, ...)"
         """
         param_reprs = []
         for name, param in self.named_parameters():
@@ -90,32 +114,98 @@ class ParameterDictModel(nn.Module):
             param_reprs.append(param_repr)
         return f"{self.__class__.__name__}({', '.join(param_reprs)})"
 
-    def
-
+    def __iter__(self) -> Iterator[str]:
+        """
+        Iterate over the model's parameters.
+
+        Yields:
+            Tuples of (parameter name, parameter tensor).
+        """
+        yield from self.keys()
+
+    def __getitem__(
+        self, key: str
+    ) -> Union[nn.Parameter, torch.Tensor, "ParameterDictModel"]:
+        """
+        Retrieve a parameter or nested submodule by key.
+
+        Args:
+            key: Parameter name, which can contain dots for nested access.
+
+        Returns:
+            The parameter, tensor, or nested ParameterDictModel at the specified key.
+
+        Raises:
+            KeyError: If the key is not found in the model.
+        """
+        assert isinstance(
+            key, str
+        ), f"Key must be a string, but got {type(key)}: {key}."
+        if not has_nested_attr(self, key.split(".")):
             raise KeyError(f"Key {key} not found in {self}")
-
+        key_parts = key.split(".")
         obj = self
-        for k in
+        for k in key_parts:
             obj = getattr(obj, k)
         return obj
 
-    def __setitem__(self, key: str, value: nn.Parameter):
-
-
+    def __setitem__(self, key: str, value: Union[nn.Parameter, torch.Tensor]) -> None:
+        """
+        Set a parameter at the specified key, creating nested structure if needed.
+
+        Args:
+            key: Parameter name, which can contain dots for nested assignment.
+            value: Parameter or tensor to assign.
+        """
+        if not has_nested_attr(self, key.split(".")):
+            set_nested_attr(self, key.split("."), value, check_parent=True)
         else:
-
+            set_nested_attr(self, key.split("."), value, check_parent=False)
+
+    def __contains__(self, key: str) -> bool:
+        """
+        Check if a parameter key exists in the model.
 
-
-
+        Args:
+            key: Parameter name, which can contain dots for nested checking.
+
+        Returns:
+            True if the key exists, False otherwise.
+        """
+        return has_nested_attr(self, key.split("."))
 
     def keys(self):
-
+        """
+        Return a list of all parameter names in the model.
+
+        Returns:
+            List of parameter names (including nested names with dots).
+        """
+        return self.state_dict().keys()
+
+    def items(self):
+        """
+        Return a list of (name, parameter) tuples.
+
+        Returns:
+            List of tuples containing parameter names and their corresponding tensors.
+        """
+        yield from self.state_dict().items()
 
-    def
-
+    def values(self):
+        """
+        Return a list of all parameter values in the model.
 
-
-
+        Returns:
+            List of parameter tensors.
+        """
+        yield from self.state_dict().values()
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Return the number of parameters in the model.
+
+        Returns:
+            The total number of parameters.
+        """
         return len(self.keys())
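Not part of the published diff: a minimal usage sketch of the reworked ParameterDictModel, based only on the methods shown above; the parameter names and shapes are illustrative.

    import torch
    from torch import nn

    from fusion_bench.models.parameter_dict import ParameterDictModel

    # Build a parameter container from a flat, dot-separated mapping.
    params = ParameterDictModel(
        {
            "encoder.weight": nn.Parameter(torch.zeros(10, 5)),
            "encoder.bias": nn.Parameter(torch.zeros(10)),
        }
    )

    # Dict-style access added in 0.2.32 (__getitem__, __setitem__, __contains__).
    assert "encoder.weight" in params
    params["decoder.weight"] = nn.Parameter(torch.zeros(5, 10))  # nested structure is created as needed

    # keys()/items()/values() delegate to state_dict(), so names keep their dots.
    for name, tensor in params.items():
        print(name, tuple(tensor.shape))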
fusion_bench/models/utils.py
CHANGED

@@ -1,9 +1,37 @@
-from typing import List
+from typing import Iterable, List, Optional
 
 import torch
 from torch import nn
+from torch.nn.modules.module import _IncompatibleKeys
 
-from fusion_bench.utils.
+from fusion_bench.utils.dict import dict_merge
+from fusion_bench.utils.type import StateDictType, TorchModelType
+
+
+def is_leaf_module(module: nn.Module) -> bool:
+    return len(list(module.children())) == 0
+
+
+def named_leaf_modules(
+    module: nn.Module,
+    prefix: str = "",
+    ignore_empty: bool = True,
+) -> Iterable[tuple[str, nn.Module]]:
+    """
+    Recursively find the leaf modules in a module.
+
+    Args:
+        module (nn.Module): PyTorch module.
+        prefix (str): A prefix to add to the layer names.
+
+    Returns:
+        Iterable[tuple[str, nn.Module]]: An iterable of (name, module) tuples for each leaf module.
+    """
+    for name, submodule in module.named_modules(prefix=prefix):
+        if is_leaf_module(submodule):
+            if ignore_empty and len(list(submodule.parameters())) == 0:
+                continue
+            yield name, submodule
 
 
 def del_attr(obj, names: List[str]):
@@ -104,3 +132,163 @@ def disable_dropout(model: torch.nn.Module):
     for module in model.modules():
         if isinstance(module, torch.nn.Dropout):
             module.p = 0
+
+
+def get_target_state_dict(
+    module: nn.Module,
+    target_modules: str | Iterable[str] | None = None,
+    prefix: str = "",
+    keep_vars: bool = False,
+) -> StateDictType:
+    """
+    This function retrieves the state dictionary of specified target submodules within a given module
+    of a PyTorch model or merged state dictionary from multiple submodules.
+
+    For example, if a model has submodules named "layer1", "layer2", and "layer3", and you want to get the state dictionary of "layer1" and "layer3",
+    you can call this function with `target_modules` set to `["layer1", "layer3"]`.
+    The function will return a state dictionary that includes only the parameters and buffers from those specified submodules.
+
+    Args:
+        module (nn.Module): The PyTorch module containing the target submodules.
+        target_modules (str | Iterable[str]): A single target module name or an iterable of target module names.
+            If None, the entire module's state dictionary is returned if no special attribute is set (look up the `_fusion_bench_target_modules` attribute).
+        keep_vars (bool): If True, keeps the variables in the state dictionary. Default is False.
+
+    Returns:
+        StateDictType: The state dictionary of the specified target submodules, merged if multiple are provided.
+    """
+    if target_modules is None:
+        if (
+            hasattr(module, "_fusion_bench_target_modules")
+            and module._fusion_bench_target_modules is not None
+        ):
+            return get_target_state_dict(
+                module,
+                target_modules=module._fusion_bench_target_modules,
+                prefix=prefix,
+                keep_vars=keep_vars,
+            )
+        else:
+            return module.state_dict(prefix=prefix, keep_vars=keep_vars)
+
+    if isinstance(target_modules, str):
+        target_modules = [target_modules]
+
+    state_dicts = []
+    for target_module in target_modules:
+        submodule_prefix = (
+            f"{prefix}{target_module}." if prefix else f"{target_module}."
+        )
+        submodule = module.get_submodule(target_module)
+        state_dict = submodule.state_dict(prefix=submodule_prefix, keep_vars=keep_vars)
+        state_dicts.append(state_dict)
+
+    merged_state_dict = dict_merge(state_dicts, disjoint=True)
+    return merged_state_dict
+
+
+def validate_target_modules_equal(modules: Iterable[nn.Module]) -> None:
+    """
+    Validates that the `_fusion_bench_target_modules` attribute is the same across all provided modules.
+
+    Args:
+        modules (Iterable[nn.Module]): An iterable of PyTorch modules to validate.
+
+    Raises:
+        ValueError: If the `_fusion_bench_target_modules` attribute differs among the modules.
+    """
+    model_iter = iter(modules)
+    first_module = next(model_iter)
+
+    if hasattr(first_module, "_fusion_bench_target_modules"):
+        target_modules = first_module._fusion_bench_target_modules
+    else:
+        # if the module does not have the attribute, set to None
+        target_modules = None
+
+    for module in model_iter:
+        if target_modules is None:
+            if (
+                hasattr(module, "_fusion_bench_target_modules")
+                and module._fusion_bench_target_modules != target_modules
+            ):
+                raise ValueError(
+                    "_fusion_bench_target_modules attribute differs among the provided modules."
+                )
+        else:
+            if (
+                not hasattr(module, "_fusion_bench_target_modules")
+                or module._fusion_bench_target_modules != target_modules
+            ):
+                raise ValueError(
+                    "_fusion_bench_target_modules attribute differs among the provided modules."
+                )
+
+
+def load_state_dict_into_target_modules(
+    module: TorchModelType,
+    state_dict: StateDictType,
+    target_modules: str | Iterable[str] | None = None,
+    strict: bool = True,
+    assign: bool = False,
+):
+    """
+    Load a state dictionary into specified target submodules within a given module of a PyTorch model.
+
+    This function allows you to load parameters and buffers from a state dictionary into specific submodules
+    of a PyTorch model. If the `target_modules` argument is provided, only the specified submodules will be updated
+    with the corresponding entries from the state dictionary.
+
+    Args:
+        module (nn.Module): The PyTorch module containing the target submodules.
+        state_dict (StateDictType): The state dictionary containing parameters and buffers to load.
+        target_modules (str | Iterable[str]): A single target module name or an iterable of target module names.
+            If None, the entire module's state dictionary is updated if no special attribute is set
+            (look up the `_fusion_bench_target_modules` attribute).
+        strict (bool): Whether to strictly enforce that the keys in `state_dict` match the keys returned by
+            the module's `state_dict()` function. Default is True.
+    """
+    if target_modules is None:
+        if (
+            hasattr(module, "_fusion_bench_target_modules")
+            and module._fusion_bench_target_modules is not None
+        ):
+            return load_state_dict_into_target_modules(
+                module,
+                state_dict,
+                target_modules=module._fusion_bench_target_modules,
+                strict=strict,
+                assign=assign,
+            )
+        else:
+            return module.load_state_dict(state_dict, strict=strict, assign=assign)
+
+    if isinstance(target_modules, str):
+        target_modules = [target_modules]
+
+    assert (
+        len(target_modules) > 0
+    ), "target_modules should contain at least one module name."
+    results: list[_IncompatibleKeys] = []
+    for target_module in target_modules:
+        submodule_prefix = f"{target_module}."
+        submodule_prefix_len = len(submodule_prefix)
+        submodule = module.get_submodule(target_module)
+
+        # Extract the relevant portion of the state dictionary for the submodule
+        submodule_state_dict = {
+            key[submodule_prefix_len:]: value for key, value in state_dict.items()
+        }
+
+        # Load the extracted state dictionary into the submodule
+        result = submodule.load_state_dict(
+            submodule_state_dict, strict=strict, assign=assign
+        )
+        results.append(result)
+
+    # Merge results from all submodules
+    merged_result = _IncompatibleKeys(
+        missing_keys=[key for res in results for key in res.missing_keys],
+        unexpected_keys=[key for res in results for key in res.unexpected_keys],
+    )
+    return merged_result
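Not part of the published diff: a short sketch of how the new target-module helpers in fusion_bench/models/utils.py could be used; the toy model and the "classifier"/"backbone" submodule names are assumptions for illustration.

    from torch import nn

    from fusion_bench.models.utils import (
        get_target_state_dict,
        load_state_dict_into_target_modules,
    )

    model = nn.Module()
    model.backbone = nn.Linear(16, 16)
    model.classifier = nn.Linear(16, 4)

    # Collect only the classifier's parameters, keyed with the "classifier." prefix.
    head_state_dict = get_target_state_dict(model, target_modules="classifier")
    print(list(head_state_dict.keys()))  # ['classifier.weight', 'classifier.bias']

    # Load them back into a model with the same submodule layout.
    result = load_state_dict_into_target_modules(
        model, head_state_dict, target_modules="classifier"
    )
    print(result.missing_keys, result.unexpected_keys)

    # Alternatively, mark the model once and let both helpers pick the targets up.
    model._fusion_bench_target_modules = ["classifier"]
    head_state_dict = get_target_state_dict(model)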
fusion_bench/models/wrappers/switch.py
ADDED

@@ -0,0 +1,90 @@
+"""
+This module contains a wrapper for switching between different models.
+
+For example, it can be used to switch between different classification heads for a shared backbone.
+"""
+
+import logging
+from typing import Dict, Optional
+
+from torch import nn
+
+from fusion_bench.utils.misc import first, validate_and_suggest_corrections
+
+__all__ = ["SwitchModule", "set_active_option"]
+
+log = logging.getLogger(__name__)
+
+
+def _standardize_option_name(name: str) -> str:
+    """
+    Standardizes the option name by:
+
+    - Stripping whitespace and converting to lowercase.
+    - Replacing `-` with `_` if needed.
+    - Replacing `/` with `_` if needed.
+
+    Args:
+        name (str): The option name to standardize.
+    """
+    name = name.strip().lower()
+    name = name.replace("-", "_")
+    name = name.replace("/", "_")
+    return name
+
+
+class SwitchModule(nn.Module):
+    """
+    A wrapper module that contains multiple sub-modules (options) and allows switching between them.
+
+    This is useful for multi-head models or models where different parts are activated based on the task.
+    """
+
+    def __init__(self, modules: Dict[str, nn.Module]):
+        """
+        Args:
+            modules (Dict[str, nn.Module]): A dictionary of modules to switch between.
+        """
+        super().__init__()
+        standardized_modules = {
+            _standardize_option_name(name): module for name, module in modules.items()
+        }
+        self._option_modules = nn.ModuleDict(standardized_modules)
+        self._active_option = first(self._option_modules.keys())
+
+    def set_active_option(self, option_name: str):
+        standardized_name = _standardize_option_name(option_name)
+        validate_and_suggest_corrections(standardized_name, self._option_modules.keys())
+        self._active_option = standardized_name
+
+    def forward(self, *args, **kwargs):
+        active_module = self._option_modules[self._active_option]
+        return active_module(*args, **kwargs)
+
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            active_module = self._option_modules[self._active_option]
+            if hasattr(active_module, name):
+                return getattr(active_module, name)
+            raise
+
+
+def set_active_option(module: nn.Module, option_name: str) -> list[str]:
+    """
+    Utility function to set the active option for all SwitchModule instances within a given module.
+
+    Args:
+        module (nn.Module): The module to set the active option for.
+        option_name (str): The name of the option to activate.
+
+    Returns:
+        list[str]: A list of names of submodules that were activated.
+    """
+    activated_submodules = []
+    for name, submodule in module.named_modules():
+        if isinstance(submodule, SwitchModule):
+            submodule.set_active_option(option_name)
+            activated_submodules.append(name)
+    return activated_submodules
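Not part of the published diff: a sketch of the new SwitchModule wrapper with two task-specific heads on a shared backbone; the option names and layer sizes are placeholders.

    import torch
    from torch import nn

    from fusion_bench.models.wrappers.switch import SwitchModule, set_active_option

    backbone = nn.Linear(32, 16)
    heads = SwitchModule(
        {
            "CIFAR-10": nn.Linear(16, 10),  # option names are normalized, e.g. "CIFAR-10" -> "cifar_10"
            "EuroSAT": nn.Linear(16, 10),
        }
    )
    model = nn.Sequential(backbone, heads)

    # Activate one head for every SwitchModule found inside `model`.
    activated = set_active_option(model, "cifar-10")
    logits = model(torch.randn(2, 32))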
fusion_bench/programs/base_program.py
CHANGED

@@ -75,6 +75,12 @@ class BaseHydraProgram(BaseYAMLSerializable):
         - FusionBench CLI documentation for program execution details
     """
 
+    _program = None
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._program = self
+
     @abstractmethod
     def run(self):
         """
fusion_bench/programs/fabric_fusion_program.py
CHANGED

@@ -267,6 +267,7 @@ class FabricModelFusionProgram(
         merged_model = self.method.run(self.modelpool)
         self.method.on_run_end()
 
+        report = None
         if merged_model is None:
             log.info(
                 "No merged model returned by the method. Skipping saving and evaluation."
@@ -293,5 +294,8 @@ class FabricModelFusionProgram(
             )
             os.makedirs(os.path.dirname(self.report_save_path), exist_ok=True)
             json.dump(report, open(self.report_save_path, "w"))
+            self.log_artifact(local_path=self.report_save_path)
         else:
             log.info("No task pool specified. Skipping evaluation.")
+
+        return {"merged_model": merged_model, "report": report}
fusion_bench/scripts/cli.py
CHANGED

@@ -9,7 +9,6 @@ from typing import TYPE_CHECKING
 import hydra
 from omegaconf import DictConfig, OmegaConf
 
-from fusion_bench.constants import PROJECT_ROOT_PATH
 from fusion_bench.utils import instantiate
 from fusion_bench.utils.hydra_utils import get_default_config_path
 
@@ -19,11 +18,6 @@ if TYPE_CHECKING:
 log = logging.getLogger(__name__)
 
 
-@hydra.main(
-    config_path=get_default_config_path(),
-    config_name="fabric_model_fusion",
-    version_base=None,
-)
 def main(cfg: DictConfig) -> None:
     """
     Main entry point for the FusionBench command-line interface.
@@ -74,8 +68,25 @@ def main(cfg: DictConfig) -> None:
         err_msg += f"\n\nConfiguration content:\n{cfg}"
         raise TypeError(err_msg)
 
-
+    try:
+        program_result = program.run()
+        return program_result
+    except BaseException as e:
+        # Log the exception before exiting
+        if hasattr(program, "finalize") and callable(getattr(program, "finalize")):
+            program.finalize()
+        log.error(e, exc_info=True)
+        raise e
+
+
+@hydra.main(
+    config_path=get_default_config_path(),
+    config_name="fabric_model_fusion",
+    version_base=None,
+)
+def _hydra_main(cfg: DictConfig) -> None:
+    main(cfg)
 
 
 if __name__ == "__main__":
-
+    _hydra_main()
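Not part of the published diff: because the @hydra.main decorator now lives on _hydra_main, main(cfg) can be called programmatically with a composed config. A hedged sketch using Hydra's Compose API; the override shown is a placeholder, and it assumes get_default_config_path() returns an absolute config directory.

    from hydra import compose, initialize_config_dir

    from fusion_bench.scripts.cli import main
    from fusion_bench.utils.hydra_utils import get_default_config_path

    # Compose the same top-level config that the `fusion_bench` CLI would build.
    # Assumes get_default_config_path() yields an absolute directory path.
    with initialize_config_dir(config_dir=get_default_config_path(), version_base=None):
        cfg = compose(
            config_name="fabric_model_fusion",
            overrides=["method=simple_average"],  # placeholder override
        )

    # main() now returns the program's result, e.g. the dict
    # {"merged_model": ..., "report": ...} produced by FabricModelFusionProgram.run().
    result = main(cfg)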