fusion-bench 0.2.31__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/__main__.py +2 -2
- fusion_bench/dataset/__init__.py +2 -0
- fusion_bench/dataset/clip_dataset.py +4 -72
- fusion_bench/dataset/image_dataset.py +44 -18
- fusion_bench/method/base_algorithm.py +4 -0
- fusion_bench/method/dop/dop.py +0 -22
- fusion_bench/method/dop/dop_general.py +489 -0
- fusion_bench/method/dop/utils.py +24 -4
- fusion_bench/method/emr_merging/__init__.py +1 -0
- fusion_bench/method/emr_merging/emr_merging.py +53 -0
- fusion_bench/method/emr_merging/utils.py +162 -0
- fusion_bench/method/opcm/opcm.py +6 -2
- fusion_bench/method/opcm/opcm_general.py +356 -0
- fusion_bench/method/opcm/utils.py +1 -4
- fusion_bench/method/simple_average.py +52 -18
- fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +108 -3
- fusion_bench/mixins/serialization.py +1 -1
- fusion_bench/modelpool/base_pool.py +37 -1
- fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
- fusion_bench/models/hf_clip.py +20 -0
- fusion_bench/models/modulator/__init__.py +1 -0
- fusion_bench/models/modulator/base.py +123 -0
- fusion_bench/models/parameter_dict.py +119 -29
- fusion_bench/models/utils.py +190 -2
- fusion_bench/models/wrappers/switch.py +90 -0
- fusion_bench/programs/base_program.py +6 -0
- fusion_bench/programs/fabric_fusion_program.py +4 -0
- fusion_bench/scripts/cli.py +19 -8
- fusion_bench/taskpool/image_classification.py +270 -0
- fusion_bench/utils/__init__.py +18 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/dict.py +19 -0
- fusion_bench/utils/dtype.py +19 -0
- fusion_bench/utils/misc.py +1 -0
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/state_dict_arithmetic.py +183 -1
- fusion_bench/utils/tensorboard.py +21 -3
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +51 -37
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
- fusion_bench_config/method/dop/dop_general.yaml +33 -0
- fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
- fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
- fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0

fusion_bench/method/simple_average.py CHANGED

@@ -3,11 +3,16 @@ from copy import deepcopy
 from typing import Dict, List, Mapping, Optional, Union
 
 import torch
-from torch import nn
+from torch import Tensor, nn
 
 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.models.utils import (
+    get_target_state_dict,
+    load_state_dict_into_target_modules,
+    validate_target_modules_equal,
+)
 from fusion_bench.utils import LazyStateDict
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
@@ -21,21 +26,22 @@ log = logging.getLogger(__name__)
 
 
 def simple_average(
-    modules: List[Union[nn.Module, StateDictType]],
-    base_module: Optional[nn.Module] = None,
+    modules: List[Union[nn.Module, StateDictType, Tensor]],
+    base_module: Optional[Union[nn.Module, StateDictType, Tensor]] = None,
 ):
     R"""
     Averages the parameters of a list of PyTorch modules or state dictionaries.
 
     This function takes a list of PyTorch modules or state dictionaries and returns a new module with the averaged parameters, or a new state dictionary with the averaged parameters.
 
+    If `_fusion_bench_target_modules` attribute is set on the modules, only the parameters of the specified target submodules will be averaged.
+
     Args:
-        modules (List[Union[nn.Module, StateDictType]]): A list of PyTorch modules or state dictionaries.
-        base_module (Optional[nn.Module]): A base module to use for the new module. If provided, the averaged parameters will be loaded into this module. If not provided, a new module will be created by copying the first module in the list.
+        modules (List[Union[nn.Module, StateDictType, Tensor]]): A list of PyTorch modules or state dictionaries.
+        base_module (Optional[Union[nn.Module, StateDictType, Tensor]]): A base module to use for the new module. If provided, the averaged parameters will be loaded into this module. If not provided, a new module will be created by copying the first module in the list.
 
     Returns:
-        module_or_state_dict (Union[nn.Module, StateDictType]): A new PyTorch module with the averaged parameters, or a new state dictionary with the averaged parameters.
-
+        module_or_state_dict (Union[nn.Module, StateDictType, Tensor]): A new PyTorch module with the averaged parameters, or a new state dictionary with the averaged parameters.
     Examples:
         >>> import torch.nn as nn
         >>> model1 = nn.Linear(10, 10)
@@ -47,23 +53,42 @@ def simple_average(
         >>> averaged_state_dict = simple_average([state_dict1, state_dict2])
     """
     assert len(modules) > 0, "modules must be a non-empty list"
+    validate_target_modules_equal(modules)
+
     if isinstance(modules[0], nn.Module):
         if base_module is None:
             new_module = deepcopy(modules[0])
         else:
             new_module = base_module
-        state_dict = state_dict_avg(
-
+        state_dict = state_dict_avg(
+            [get_target_state_dict(module) for module in modules]
+        )
+        load_state_dict_into_target_modules(new_module, state_dict)
         return new_module
     elif isinstance(modules[0], Mapping):
-
+        # if the modules are state dicts
+        # compute the average state dict
+        avg_state_dict = state_dict_avg(modules)
+        # load into base_module if provided
+        if base_module is not None:
+            for k in avg_state_dict:
+                base_module[k] = avg_state_dict[k]
+            return base_module
+        else:
+            return avg_state_dict
+    elif isinstance(modules[0], Tensor):
+        mean_tensor = torch.stack(modules, dim=0).mean(dim=0)
+        if base_module is not None:
+            base_module.data = mean_tensor
+            return base_module
+        else:
+            return mean_tensor
+    else:
+        raise ValueError(f"Unsupported type: {type(modules[0])}")
 
 
 @auto_register_config
-class SimpleAverageAlgorithm(
-    SimpleProfilerMixin,
-    BaseAlgorithm,
-):
+class SimpleAverageAlgorithm(SimpleProfilerMixin, BaseAlgorithm):
     def __init__(self, show_pbar: bool = False, inplace: bool = True, **kwargs):
         """
         Args:
@@ -87,13 +112,20 @@ class SimpleAverageAlgorithm(
         Returns:
             The fused model obtained by simple averaging.
         """
-        if isinstance(modelpool,
+        if not isinstance(modelpool, BaseModelPool):
             modelpool = BaseModelPool(modelpool)
 
         log.info(
             f"Fusing models using simple average on {len(modelpool.model_names)} models. "
             f"models: {modelpool.model_names}"
         )
+        if modelpool.has_instance_models and self.inplace:
+            log.warning(
+                "The model pool contains instance models, and inplace is set to True. "
+                "Therefore, the weights of the first model will be overwritten. "
+                "If this is desired behavior, this warning can be ignored."
+            )
+
         sd: Optional[StateDictType] = None
         forward_model = None
         merged_model_names = []
@@ -106,12 +138,12 @@ class SimpleAverageAlgorithm(
             with self.profile("merge weights"):
                 if sd is None:
                     # Initialize the state dictionary with the first model's state dictionary
-                    sd = model
+                    sd = get_target_state_dict(model)
                     forward_model = model if self.inplace else deepcopy(model)
                 else:
                     # Add the current model's state dictionary to the accumulated state dictionary
                     sd = state_dict_add(
-                        sd, model
+                        sd, get_target_state_dict(model), show_pbar=self.show_pbar
                     )
         with self.profile("merge weights"):
             # Divide the accumulated state dictionary by the number of models to get the average
@@ -124,11 +156,13 @@ class SimpleAverageAlgorithm(
             forward_model = deepcopy(forward_model.meta_module).to_empty(
                 device=forward_model._device
            )
-
+
+        result = load_state_dict_into_target_modules(forward_model, sd, strict=False)
         if result.unexpected_keys:
             raise ValueError(f"Unexpected keys in state dict: {result.unexpected_keys}")
         if result.missing_keys:
             log.warning(f"Missing keys in state dict: {result.missing_keys}")
+
         # print profile report and log the merged models
         self.print_profile_summary()
         log.info(f"merged {len(merged_model_names)} models:")
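
For orientation, a minimal usage sketch of the extended simple_average signature shown above (a hedged example, assuming fusion-bench 0.2.32 is installed; behavior of the new Tensor branch is taken from the hunk above):

import torch
from torch import nn

from fusion_bench.method.simple_average import simple_average

# Averaging modules returns a new module with averaged parameters.
model1, model2 = nn.Linear(10, 10), nn.Linear(10, 10)
merged_module = simple_average([model1, model2])

# Averaging state dicts returns an averaged state dict.
avg_sd = simple_average([model1.state_dict(), model2.state_dict()])

# New in 0.2.32: plain tensors are averaged element-wise via torch.stack(...).mean(dim=0).
t1, t2 = torch.randn(3, 3), torch.randn(3, 3)
avg_t = simple_average([t1, t2])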

fusion_bench/method/task_arithmetic/task_arithmetic.py CHANGED

@@ -50,7 +50,7 @@ def task_arithmetic_merge(
         finetuned_models (List[nn.Module]): A list of fine-tuned models from which task vectors will be calculated.
         scaling_factor (float): A factor by which the task vectors will be scaled before merging.
         inplace (bool, optional): If True, the pre-trained model will be modified in place.
-
+            If False, a copy of the pre-trained model will be modified. Defaults to True.
 
     Returns:
         nn.Module: The pre-trained model with the merged task vectors.
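
A brief, hedged sketch of calling task_arithmetic_merge as described by the docstring above (keyword names follow the Args list; the pretrained_model parameter name is assumed from context, and toy nn.Linear models stand in for real checkpoints):

from torch import nn

from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge

pretrained = nn.Linear(4, 4)
finetuned = [nn.Linear(4, 4), nn.Linear(4, 4)]

# Task vectors (finetuned - pretrained) are scaled by scaling_factor and added to the
# pretrained weights; with inplace=True (the documented default) the pretrained model
# itself is modified and returned.
merged = task_arithmetic_merge(
    pretrained_model=pretrained,
    finetuned_models=finetuned,
    scaling_factor=0.3,
)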

fusion_bench/mixins/lightning_fabric.py CHANGED

@@ -1,6 +1,7 @@
 import functools
 import logging
 import os
+import sys
 from typing import TYPE_CHECKING, Any, List, Mapping, Optional, TypeVar
 
 import lightning as L
@@ -12,17 +13,32 @@ from omegaconf import DictConfig, OmegaConf
 
 from fusion_bench.constants import RuntimeConstants
 from fusion_bench.utils import import_object
+from fusion_bench.utils.hydra_utils import get_hydra_output_dir
 from fusion_bench.utils.instantiate_utils import instantiate
 
 if TYPE_CHECKING:
     import lightning.fabric.loggers.tensorboard
     from lightning.fabric.strategies import FSDPStrategy
+    from lightning.pytorch.loggers import MLFlowLogger
+    from mlflow.tracking.client import MlflowClient
 
 log = logging.getLogger(__name__)
 
 TensorOrModule = TypeVar("TensorOrModule", torch.Tensor, torch.nn.Module, Any)
 
 
+def _fabric_has_logger(fabric: L.Fabric) -> bool:
+    """
+    Check if the fabric has a logger.
+
+    Args:
+        fabric (L.Fabric): The Lightning Fabric instance.
+    Returns:
+        bool: True if the fabric has a logger, False otherwise.
+    """
+    return fabric._loggers is not None and len(fabric._loggers) > 0
+
+
 def get_policy(*args: str) -> set:
     """
     Get the policy from the provided list of policy names.
@@ -43,6 +59,21 @@ def get_size_based_auto_wrap_policy(*args, **kwargs):
     return policy
 
 
+def _is_mlflow_logger(fabric: L.Fabric) -> bool:
+    """
+    Check if the fabric's logger is an instance of MLFlowLogger.
+
+    Args:
+        fabric (L.Fabric): The Lightning Fabric instance.
+
+    Returns:
+        bool: True if the logger is an instance of MLFlowLogger, False otherwise.
+    """
+    if not _fabric_has_logger(fabric):
+        return False
+    return fabric.logger.__class__.__name__ == "MLFlowLogger"
+
+
 class LightningFabricMixin:
     """
     A mixin class for integrating Lightning Fabric into a project.
@@ -79,8 +110,8 @@ class LightningFabricMixin:
         """
         if self._fabric_instance is None:
             if config.get("fabric", None) is None:
-                log.warning("No fabric configuration found. use default settings.")
-                self._fabric_instance = L.Fabric()
+                log.warning("No fabric configuration found. use default settings. By default, use 1 device.")
+                self._fabric_instance = L.Fabric(devices=1)
             else:
                 self._fabric_instance = instantiate(config.fabric)
             if not _is_using_cli():  # if not using cli, launch the fabric
@@ -123,7 +154,10 @@
         Retrieves the log directory from the fabric's logger.
         """
         if self.fabric is not None and len(self.fabric._loggers) > 0:
-
+            if hasattr(self.fabric.logger, "log_dir"):
+                log_dir = self.fabric.logger.log_dir
+            else:
+                log_dir = None
 
             # Special handling for SwanLabLogger to get the correct log directory
             if (
@@ -132,6 +166,20 @@
             ):
                 log_dir = self.fabric.logger.save_dir or self.fabric.logger._logdir
 
+            if (
+                log_dir is None
+                and self.fabric.logger.__class__.__name__ == "MLFlowLogger"
+            ):
+                log_dir = self.fabric.logger.save_dir
+                if log_dir is None:
+                    try:
+                        log_dir = self._program.config.path.log_dir
+                    except Exception:
+                        log.error(
+                            "Failed to get log_dir from program config for MLFlowLogger."
+                        )
+                        log_dir = "outputs"
+
             assert log_dir is not None, "log_dir should not be None"
             if self.fabric.is_global_zero and not os.path.exists(log_dir):
                 os.makedirs(log_dir, exist_ok=True)
@@ -246,3 +294,60 @@ class LightningFabricMixin:
         """
         for i, param_group in enumerate(optimizer.param_groups):
             self.fabric.log(name_template.format(i), param_group["lr"], step=step)
+
+    def log_artifact(self, local_path: str, artifact_path: str | None = None):
+        """
+        Logs a file as an artifact to the fabric's logger.
+
+        Args:
+            local_dir: The path to the directory to log as an artifact.
+            artifact_path: The directory within the logger's artifact storage to save the file.
+        """
+        if _is_mlflow_logger(self.fabric):
+            logger: "MLFlowLogger" = self.fabric.logger
+            experiment: "MlflowClient" = logger.experiment
+            experiment.log_artifact(
+                logger.run_id,
+                local_path=local_path,
+                artifact_path=(artifact_path),
+            )
+
+    def log_artifacts(self, local_dir: str, artifact_path: str | None = None):
+        """
+        Logs a directory as artifacts to the fabric's logger.
+
+        Args:
+            local_dir: The path to the directory to log as artifacts.
+            artifact_path: The directory within the logger's artifact storage to save the files.
+        """
+        if _is_mlflow_logger(self.fabric):
+            logger: "MLFlowLogger" = self.fabric.logger
+            experiment: "MlflowClient" = logger.experiment
+            experiment.log_artifacts(
+                logger.run_id,
+                local_dir=local_dir,
+                artifact_path=artifact_path,
+            )
+
+    def finalize(self):
+        """
+        Destructor to ensure proper cleanup of the Lightning Fabric instance.
+        """
+        if self._fabric_instance is None:
+            return
+
+        if _fabric_has_logger(self.fabric) and _is_mlflow_logger(self.fabric):
+            if sys.exc_info()[0] is None:
+                status = "success"
+            else:
+                status = "failed"
+            self.fabric.logger.finalize(status)
+
+        del self._fabric_instance
+        self._fabric_instance = None
+
+    def __del__(self):
+        """
+        Destructor to ensure proper cleanup of the Lightning Fabric instance.
+        """
+        self.finalize()
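
A hedged sketch of how the new artifact-logging hooks above might be used from a class that mixes in LightningFabricMixin (it assumes an MLflow-backed fabric logger is configured, for example via the new fusion_bench_config/fabric/loggers/mlflow_logger.yaml listed at the top of this diff):

from fusion_bench.mixins.lightning_fabric import LightningFabricMixin


class ArtifactLoggingDemo(LightningFabricMixin):
    def save_and_log(self, output_dir: str) -> None:
        # Both calls are silently skipped unless the fabric's logger is an
        # MLFlowLogger, per the _is_mlflow_logger guard in the diff above.
        self.log_artifact(f"{output_dir}/metrics.json")
        self.log_artifacts(output_dir, artifact_path="run_outputs")
        # finalize() marks the MLflow run as "success" or "failed" depending on
        # whether an exception is currently propagating (sys.exc_info()).
        self.finalize()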

fusion_bench/mixins/serialization.py CHANGED

@@ -68,7 +68,7 @@ def auto_register_config(cls):
 
     Behavior:
        - **Parameter Registration**: All non-variadic parameters (excluding ``*args``, ``**kwargs``)
-
+          from the __init__ method are automatically added to _config_mapping
        - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
        - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
        - **Default Values**: Applied when parameters are not provided via arguments
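
A hedged sketch of the decorator behavior described above, mirroring the SimpleAverageAlgorithm pattern from this same release: non-variadic __init__ parameters are registered into _config_mapping automatically.

from fusion_bench.method.base_algorithm import BaseAlgorithm
from fusion_bench.mixins import auto_register_config


@auto_register_config
class MyToyAlgorithm(BaseAlgorithm):
    def __init__(self, scaling_factor: float = 0.3, show_pbar: bool = False, **kwargs):
        # scaling_factor and show_pbar are picked up by the decorator;
        # *args/**kwargs are excluded, per the Behavior notes in the docstring above.
        super().__init__(**kwargs)

    def run(self, modelpool):
        raise NotImplementedError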

fusion_bench/modelpool/base_pool.py CHANGED

@@ -7,11 +7,12 @@ from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
 from torch import nn
 from torch.utils.data import Dataset
 
-from fusion_bench import TorchModelType
+from fusion_bench import StateDictType, TorchModelType
 from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
 from fusion_bench.utils import (
     ValidationError,
     instantiate,
+    state_dict_sub,
     timeit_context,
     validate_model_name,
 )
@@ -57,6 +58,10 @@ class BaseModelPool(
         **kwargs,
     ):
         if isinstance(models, List):
+            log.debug(
+                "Initializing BaseModelPool with a list of models. "
+                "Converting to a dictionary with integer string keys."
+            )
             models = {str(model_idx): model for model_idx, model in enumerate(models)}
 
         if isinstance(models, dict):
@@ -81,6 +86,22 @@ class BaseModelPool(
         self._test_datasets = test_datasets
         super().__init__(**kwargs)
 
+    @property
+    def has_instance_models(self) -> bool:
+        """
+        Check if the model pool contains any pre-instantiated models.
+
+        Attention:
+            Some algorithms may modify the models in-place if they are pre-instantiated.
+
+        Returns:
+            bool: True if there are pre-instantiated models, False otherwise.
+        """
+        for model_cfg in self._models.values():
+            if isinstance(model_cfg, nn.Module):
+                return True
+        return False
+
     @property
     def has_pretrained(self) -> bool:
         """
@@ -329,6 +350,21 @@ class BaseModelPool(
         for model_name in self.model_names:
             yield model_name, self.load_model(model_name)
 
+    def load_pretrained_model_and_task_vectors(
+        self,
+    ) -> Tuple[TorchModelType, List[StateDictType]]:
+        pretrained_model = self.load_pretrained_model()
+
+        task_vectors = []
+        for model_name in self.model_names:
+            finetuned_model = self.load_model(model_name)
+            task_vector = state_dict_sub(
+                finetuned_model.state_dict(), pretrained_model.state_dict()
+            )
+            task_vectors.append(task_vector)
+
+        return pretrained_model, task_vectors
+
     @property
     def has_train_dataset(self) -> bool:
         """
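
A hedged sketch of the two BaseModelPool additions above, using toy pre-instantiated modules ("_pretrained_" is the conventional key fusion_bench uses for the base model; real pools would normally hold model configs instead):

from torch import nn

from fusion_bench.modelpool import BaseModelPool

pool = BaseModelPool(
    models={
        "_pretrained_": nn.Linear(8, 8),
        "task_a": nn.Linear(8, 8),
        "task_b": nn.Linear(8, 8),
    }
)

# True here, because the pool entries are already nn.Module instances.
print(pool.has_instance_models)

# Returns the pretrained model plus one task vector per named model, where each
# task vector is the state_dict_sub difference (finetuned - pretrained).
pretrained, task_vectors = pool.load_pretrained_model_and_task_vectors()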

fusion_bench/modelpool/convnext_for_image_classification.py CHANGED

@@ -98,7 +98,7 @@ class ConvNextForImageClassificationPool(BaseModelPool):
    - Load ConvNeXt models either from a pretrained checkpoint or from config.
    - Optionally adapt the classifier head to match dataset classnames.
    - Override `forward` to return logits for consistent interfaces within
-
+      FusionBench.
 
    See `fusion_bench.modelpool.resnet_for_image_classification` for a closely
    related ResNet-based pool with analogous behavior.
@@ -161,6 +161,9 @@ class ConvNextForImageClassificationPool(BaseModelPool):
            ).logits
        model.original_forward = original_forward
 
+        # Mark ConvNeXt layers for FusionBench fusion
+        model._fusion_bench_target_modules = ["convnext"]
+
        return model
 
    @override
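
For context, a hedged sketch of what the `_fusion_bench_target_modules` marker is used for: helpers such as get_target_state_dict (used by simple_average earlier in this diff) restrict merging to the named submodules, leaving e.g. classifier heads untouched. The toy module and the exact key format below are assumptions.

from torch import nn

from fusion_bench.models.utils import get_target_state_dict


class ToyConvNextClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.convnext = nn.Linear(4, 4)    # stands in for the ConvNeXt backbone
        self.classifier = nn.Linear(4, 2)  # task-specific head, excluded from merging


model = ToyConvNextClassifier()
model._fusion_bench_target_modules = ["convnext"]

# Expected to contain only parameters belonging to the `convnext` submodule.
backbone_sd = get_target_state_dict(model)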
@@ -180,7 +183,7 @@ class ConvNextForImageClassificationPool(BaseModelPool):
        - The ConvNeXt model via `model.save_pretrained`.
        - The paired image processor via `AutoImageProcessor.save_pretrained`.
        - If `algorithm_config` is provided and on rank-zero, a README model card
-
+          documenting the FusionBench configuration.
        """
        model.save_pretrained(path)
        self.load_processor().save_pretrained(path)

fusion_bench/models/hf_clip.py CHANGED

@@ -62,16 +62,36 @@ class HFCLIPClassifier(nn.Module):
            persistent=False,
        )
 
+    # NOTE:
+    # The property setters seems not to work properly with `nn.Module` attributes.
+    # So avoid using them in practice.
+    # To set the text or vision model, directly access the attributes.
+    # For example:
+    #   classifier.clip_model.text_model = new_text_model
+    # or
+    #   classifier.clip_model.vision_model = new_vision_model
+    # reference: https://github.com/pytorch/pytorch/issues/52664
+
    @property
    def text_model(self):
        """Get the text model component of CLIP."""
        return self.clip_model.text_model
 
+    @text_model.setter
+    def text_model(self, model: nn.Module):
+        """Set the text model component of CLIP."""
+        self.clip_model.text_model = model
+
    @property
    def vision_model(self):
        """Get the vision model component of CLIP."""
        return self.clip_model.vision_model
 
+    @vision_model.setter
+    def vision_model(self, model: nn.Module):
+        """Set the vision model component of CLIP."""
+        self.clip_model.vision_model = model
+
    def set_classification_task(
        self,
        classnames: List[str],
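
A hedged sketch following the NOTE added above: sub-models are swapped by assigning through clip_model directly rather than via the property setters (the checkpoint name and constructor call are illustrative assumptions):

from copy import deepcopy

from transformers import CLIPModel, CLIPProcessor

from fusion_bench.models.hf_clip import HFCLIPClassifier

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
classifier = HFCLIPClassifier(clip_model, processor)

# e.g. a merged vision encoder produced by one of the fusion algorithms
new_vision_model = deepcopy(clip_model.vision_model)

# Recommended per the comment in this diff: bypass the property setter.
classifier.clip_model.vision_model = new_vision_model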

fusion_bench/models/modulator/__init__.py ADDED

@@ -0,0 +1 @@
+from .base import ModulatedModel, TaskModulator

fusion_bench/models/modulator/base.py ADDED

@@ -0,0 +1,123 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Generic, Optional
+
+import torch
+from torch import nn
+
+from fusion_bench import TorchModelType
+
+log = logging.getLogger(__name__)
+
+
+class ModulatedModel(nn.Module, Generic[TorchModelType]):
+    """
+    A model wrapper that uses task-specific modulators to adapt a shared backbone
+    for different tasks.
+
+    The model maintains a shared backbone and task-specific modulators. During forward pass,
+    the appropriate modulator is applied based on the current task.
+    """
+
+    _current_task: Optional[str] = None
+
+    def __init__(
+        self,
+        backbone: TorchModelType,
+        modulators: Dict[str, "TaskModulator[TorchModelType]"],
+    ):
+        super().__init__()
+        self.backbone = backbone
+        self.modulators = nn.ModuleDict(modulators)
+
+    def add_modulator(self, task_name: str, modulator: "TaskModulator[TorchModelType]"):
+        """Add a new task-specific modulator."""
+        if task_name in self.modulators:
+            raise ValueError(f"Modulator for task '{task_name}' already exists.")
+        self.modulators[task_name] = modulator
+
+    def remove_modulator(self, task_name: str):
+        """Remove an existing task-specific modulator."""
+        if task_name not in self.modulators:
+            raise ValueError(f"Modulator for task '{task_name}' does not exist.")
+        if self._current_task == task_name:
+            log.warning(
+                f"Removing modulator for current task '{task_name}'. "
+                "This will make unset the current task unpredictable."
+            )
+        del self.modulators[task_name]
+
+    def set_task(self, task_name: str):
+        """Set the current task for inference."""
+        if task_name not in self.modulators:
+            raise ValueError(
+                f"Task '{task_name}' not found in modulators. Available tasks: {list(self.modulators.keys())}"
+            )
+        if self._current_task == task_name:
+            return
+
+        # unset previous task
+        if self._current_task is not None:
+            self.modulators[self._current_task].remove(self)
+            assert (
+                self._current_task is None
+            ), "Current task should be None after removal."
+
+        # set new task
+        self.modulators[task_name].apply(self)
+        self._current_task = task_name
+
+    @property
+    def current_task(self) -> Optional[str]:
+        """Get the current task name."""
+        return self._current_task
+
+    def forward(self, *args, **kwargs) -> Any:
+        """
+        Forward pass with task-specific modulation.
+
+        Args:
+            *args: Positional arguments for the backbone model
+            **kwargs: Keyword arguments for the backbone model
+
+        Returns:
+            Model output after applying task-specific modulation
+        """
+        if self._current_task is None:
+            raise ValueError(
+                "No task specified. Set current_task or provide 'task' argument."
+            )
+
+        return self.backbone(*args, **kwargs)
+
+
+class TaskModulator(nn.Module, Generic[TorchModelType], ABC):
+    """
+    Lightweight, task-specific parameterization that modulates
+    a shared representation.
+
+    This is the base class for all task modulators. Subclasses should implement
+    the `apply` method to define how the modulator adapts the backbone model
+    for a specific task.
+    """
+
+    @abstractmethod
+    def apply(self, modulated_model: "ModulatedModel[TorchModelType]"):
+        """
+        Apply task-specific modulation to the backbone model.
+
+        Args:
+            modulated_model: The modulated model
+        """
+        raise NotImplementedError("Subclasses must implement the apply method.")
+
+    @abstractmethod
+    def remove(self, modulated_model: "ModulatedModel[TorchModelType]"):
+        """
+        Remove task-specific modulation from the backbone model.
+        This is called when switching tasks.
+
+        Args:
+            modulated_model: The modulated model
+        """
+        raise NotImplementedError("Subclasses must implement the remove method.")