fusion-bench 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/__main__.py +2 -2
- fusion_bench/constants/runtime.py +4 -1
- fusion_bench/dataset/__init__.py +2 -0
- fusion_bench/dataset/clip_dataset.py +4 -72
- fusion_bench/dataset/image_dataset.py +44 -18
- fusion_bench/method/base_algorithm.py +4 -0
- fusion_bench/method/classification/image_classification_finetune.py +1 -0
- fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
- fusion_bench/method/dop/dop.py +0 -22
- fusion_bench/method/dop/dop_general.py +489 -0
- fusion_bench/method/dop/utils.py +24 -4
- fusion_bench/method/emr_merging/__init__.py +1 -0
- fusion_bench/method/emr_merging/emr_merging.py +53 -0
- fusion_bench/method/emr_merging/utils.py +162 -0
- fusion_bench/method/opcm/opcm.py +6 -2
- fusion_bench/method/opcm/opcm_general.py +356 -0
- fusion_bench/method/opcm/utils.py +1 -4
- fusion_bench/method/simple_average.py +52 -18
- fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
- fusion_bench/method/task_singular_vector/TSVM.py +7 -6
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
- fusion_bench/mixins/lightning_fabric.py +110 -11
- fusion_bench/mixins/openclip_classification.py +155 -1
- fusion_bench/mixins/serialization.py +1 -1
- fusion_bench/modelpool/base_pool.py +37 -0
- fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
- fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
- fusion_bench/models/hf_clip.py +20 -0
- fusion_bench/models/modulator/__init__.py +1 -0
- fusion_bench/models/modulator/base.py +123 -0
- fusion_bench/models/open_clip/modeling.py +61 -5
- fusion_bench/models/open_clip/utils.py +13 -2
- fusion_bench/models/parameter_dict.py +119 -29
- fusion_bench/models/utils.py +190 -2
- fusion_bench/models/wrappers/switch.py +90 -0
- fusion_bench/programs/base_program.py +6 -0
- fusion_bench/programs/fabric_fusion_program.py +4 -0
- fusion_bench/py.typed +1 -0
- fusion_bench/scripts/cli.py +25 -23
- fusion_bench/scripts/imgui.py +2 -2
- fusion_bench/scripts/webui.py +2 -2
- fusion_bench/taskpool/image_classification.py +270 -0
- fusion_bench/utils/__init__.py +20 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/dict.py +19 -0
- fusion_bench/utils/dtype.py +19 -0
- fusion_bench/utils/hydra_utils.py +75 -0
- fusion_bench/utils/misc.py +1 -0
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/rich_utils.py +42 -19
- fusion_bench/utils/state_dict_arithmetic.py +183 -1
- fusion_bench/utils/tensorboard.py +21 -3
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +70 -53
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
- fusion_bench_config/README.md +9 -0
- fusion_bench_config/fabric/auto.yaml +1 -0
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
- fusion_bench_config/hydra/default.yaml +3 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
- fusion_bench_config/method/dop/dop_general.yaml +33 -0
- fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
- fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
- fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/method/simple_average.py:
@@ -3,11 +3,16 @@ from copy import deepcopy
 from typing import Dict, List, Mapping, Optional, Union

 import torch
-from torch import nn
+from torch import Tensor, nn

 from fusion_bench.method.base_algorithm import BaseAlgorithm
 from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.models.utils import (
+    get_target_state_dict,
+    load_state_dict_into_target_modules,
+    validate_target_modules_equal,
+)
 from fusion_bench.utils import LazyStateDict
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
@@ -21,21 +26,22 @@ log = logging.getLogger(__name__)


 def simple_average(
-    modules: List[Union[nn.Module, StateDictType]],
-    base_module: Optional[nn.Module] = None,
+    modules: List[Union[nn.Module, StateDictType, Tensor]],
+    base_module: Optional[Union[nn.Module, StateDictType, Tensor]] = None,
 ):
     R"""
     Averages the parameters of a list of PyTorch modules or state dictionaries.

     This function takes a list of PyTorch modules or state dictionaries and returns a new module with the averaged parameters, or a new state dictionary with the averaged parameters.

+    If `_fusion_bench_target_modules` attribute is set on the modules, only the parameters of the specified target submodules will be averaged.
+
     Args:
-        modules (List[Union[nn.Module, StateDictType]]): A list of PyTorch modules or state dictionaries.
-        base_module (Optional[nn.Module]): A base module to use for the new module. If provided, the averaged parameters will be loaded into this module. If not provided, a new module will be created by copying the first module in the list.
+        modules (List[Union[nn.Module, StateDictType, Tensor]]): A list of PyTorch modules or state dictionaries.
+        base_module (Optional[Union[nn.Module, StateDictType, Tensor]]): A base module to use for the new module. If provided, the averaged parameters will be loaded into this module. If not provided, a new module will be created by copying the first module in the list.

     Returns:
-        module_or_state_dict (Union[nn.Module, StateDictType]): A new PyTorch module with the averaged parameters, or a new state dictionary with the averaged parameters.
-
+        module_or_state_dict (Union[nn.Module, StateDictType, Tensor]): A new PyTorch module with the averaged parameters, or a new state dictionary with the averaged parameters.
     Examples:
         >>> import torch.nn as nn
         >>> model1 = nn.Linear(10, 10)
@@ -47,23 +53,42 @@ def simple_average(
         >>> averaged_state_dict = simple_average([state_dict1, state_dict2])
     """
     assert len(modules) > 0, "modules must be a non-empty list"
+    validate_target_modules_equal(modules)
+
     if isinstance(modules[0], nn.Module):
         if base_module is None:
            new_module = deepcopy(modules[0])
         else:
            new_module = base_module
-        state_dict = state_dict_avg(
-
+        state_dict = state_dict_avg(
+            [get_target_state_dict(module) for module in modules]
+        )
+        load_state_dict_into_target_modules(new_module, state_dict)
         return new_module
     elif isinstance(modules[0], Mapping):
-
+        # if the modules are state dicts
+        # compute the average state dict
+        avg_state_dict = state_dict_avg(modules)
+        # load into base_module if provided
+        if base_module is not None:
+            for k in avg_state_dict:
+                base_module[k] = avg_state_dict[k]
+            return base_module
+        else:
+            return avg_state_dict
+    elif isinstance(modules[0], Tensor):
+        mean_tensor = torch.stack(modules, dim=0).mean(dim=0)
+        if base_module is not None:
+            base_module.data = mean_tensor
+            return base_module
+        else:
+            return mean_tensor
+    else:
+        raise ValueError(f"Unsupported type: {type(modules[0])}")


 @auto_register_config
-class SimpleAverageAlgorithm(
-    SimpleProfilerMixin,
-    BaseAlgorithm,
-):
+class SimpleAverageAlgorithm(SimpleProfilerMixin, BaseAlgorithm):
     def __init__(self, show_pbar: bool = False, inplace: bool = True, **kwargs):
         """
         Args:
@@ -87,13 +112,20 @@ class SimpleAverageAlgorithm(
         Returns:
             The fused model obtained by simple averaging.
         """
-        if isinstance(modelpool,
+        if not isinstance(modelpool, BaseModelPool):
            modelpool = BaseModelPool(modelpool)

         log.info(
            f"Fusing models using simple average on {len(modelpool.model_names)} models. "
            f"models: {modelpool.model_names}"
         )
+        if modelpool.has_instance_models and self.inplace:
+            log.warning(
+                "The model pool contains instance models, and inplace is set to True. "
+                "Therefore, the weights of the first model will be overwritten. "
+                "If this is desired behavior, this warning can be ignored."
+            )
+
         sd: Optional[StateDictType] = None
         forward_model = None
         merged_model_names = []
@@ -106,12 +138,12 @@ class SimpleAverageAlgorithm(
             with self.profile("merge weights"):
                 if sd is None:
                     # Initialize the state dictionary with the first model's state dictionary
-                    sd = model
+                    sd = get_target_state_dict(model)
                     forward_model = model if self.inplace else deepcopy(model)
                 else:
                     # Add the current model's state dictionary to the accumulated state dictionary
                     sd = state_dict_add(
-                        sd, model
+                        sd, get_target_state_dict(model), show_pbar=self.show_pbar
                     )
         with self.profile("merge weights"):
             # Divide the accumulated state dictionary by the number of models to get the average
@@ -124,11 +156,13 @@ class SimpleAverageAlgorithm(
            forward_model = deepcopy(forward_model.meta_module).to_empty(
                device=forward_model._device
            )
-
+
+        result = load_state_dict_into_target_modules(forward_model, sd, strict=False)
        if result.unexpected_keys:
            raise ValueError(f"Unexpected keys in state dict: {result.unexpected_keys}")
        if result.missing_keys:
            log.warning(f"Missing keys in state dict: {result.missing_keys}")
+
        # print profile report and log the merged models
        self.print_profile_summary()
        log.info(f"merged {len(merged_model_names)} models:")
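
```python
# A minimal usage sketch of the extended `simple_average` entry point, based only on
# the hunks above: plain tensors are now accepted alongside modules and state dicts.
import torch
from torch import nn

from fusion_bench.method.simple_average import simple_average

# Plain tensors: element-wise mean via torch.stack(...).mean(dim=0).
avg = simple_average([torch.zeros(3), torch.ones(3)])  # tensor([0.5, 0.5, 0.5])

# Modules: returns a copy of the first module carrying the averaged parameters.
merged_linear = simple_average([nn.Linear(10, 10), nn.Linear(10, 10)])

# State dicts: returns the averaged state dict (or fills `base_module` if provided).
avg_sd = simple_average([nn.Linear(10, 10).state_dict(), nn.Linear(10, 10).state_dict()])
```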

fusion_bench/method/task_arithmetic/task_arithmetic.py:
@@ -50,7 +50,7 @@ def task_arithmetic_merge(
        finetuned_models (List[nn.Module]): A list of fine-tuned models from which task vectors will be calculated.
        scaling_factor (float): A factor by which the task vectors will be scaled before merging.
        inplace (bool, optional): If True, the pre-trained model will be modified in place.
-
+            If False, a copy of the pre-trained model will be modified. Defaults to True.

    Returns:
        nn.Module: The pre-trained model with the merged task vectors.
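
```python
# A minimal sketch of calling `task_arithmetic_merge`. Passing the pre-trained model as
# the first positional argument is an assumption; the hunk above only shows the later
# docstring entries (finetuned_models, scaling_factor, inplace).
import torch.nn as nn

from fusion_bench.method.task_arithmetic.task_arithmetic import task_arithmetic_merge

pretrained = nn.Linear(8, 8)
finetuned_models = [nn.Linear(8, 8), nn.Linear(8, 8)]

# With inplace=False (per the clarified docstring), a copy of `pretrained` is modified
# and returned, leaving the original model untouched.
merged = task_arithmetic_merge(
    pretrained, finetuned_models, scaling_factor=0.3, inplace=False
)
```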

fusion_bench/method/task_singular_vector/TSVM.py:
@@ -249,12 +249,13 @@ class TaskSingularVectorMerging(BaseAlgorithm, LightningFabricMixin):
        # - SVD finds the principal components (most important directions)
        # - Task vectors are reconstructed using only the most significant components
        # - The reconstructed vectors are merged (summed) to create a unified task vector
-
-
-
-
-
-
+        with torch.no_grad():
+            new_merged_tv = TSVM_utils.compute_and_sum_svd_mem_reduction(
+                task_vectors,
+                exclude_keys=self.exclude_keys,  # Skip certain parameters from SVD
+                accelerator=accelerator,  # Use GPU if available
+                return_single_task_models=self.return_single_task_models,
+            )

        # Handle the case where individual transformed task vectors are also returned
        if self.return_single_task_models:

fusion_bench/method/task_singular_vector/utils/TSVM_utils.py:
@@ -311,7 +311,6 @@ def compute_and_sum_svd_mem_reduction_lossless_eigen(

###############
#### TSV Merge Orthogonalization
-@torch.no_grad()
def compute_and_sum_svd_mem_reduction(
    task_vectors: List[StateDictType],
    exclude_keys: Optional[List[str]] = None,

fusion_bench/mixins/lightning_fabric.py:
@@ -1,6 +1,7 @@
 import functools
 import logging
 import os
+import sys
 from typing import TYPE_CHECKING, Any, List, Mapping, Optional, TypeVar

 import lightning as L
@@ -10,18 +11,34 @@ from lightning.fabric.loggers import TensorBoardLogger
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf

+from fusion_bench.constants import RuntimeConstants
 from fusion_bench.utils import import_object
+from fusion_bench.utils.hydra_utils import get_hydra_output_dir
 from fusion_bench.utils.instantiate_utils import instantiate

 if TYPE_CHECKING:
     import lightning.fabric.loggers.tensorboard
     from lightning.fabric.strategies import FSDPStrategy
+    from lightning.pytorch.loggers import MLFlowLogger
+    from mlflow.tracking.client import MlflowClient

 log = logging.getLogger(__name__)

 TensorOrModule = TypeVar("TensorOrModule", torch.Tensor, torch.nn.Module, Any)


+def _fabric_has_logger(fabric: L.Fabric) -> bool:
+    """
+    Check if the fabric has a logger.
+
+    Args:
+        fabric (L.Fabric): The Lightning Fabric instance.
+    Returns:
+        bool: True if the fabric has a logger, False otherwise.
+    """
+    return fabric._loggers is not None and len(fabric._loggers) > 0
+
+
 def get_policy(*args: str) -> set:
     """
     Get the policy from the provided list of policy names.
@@ -42,6 +59,21 @@ def get_size_based_auto_wrap_policy(*args, **kwargs):
     return policy


+def _is_mlflow_logger(fabric: L.Fabric) -> bool:
+    """
+    Check if the fabric's logger is an instance of MLFlowLogger.
+
+    Args:
+        fabric (L.Fabric): The Lightning Fabric instance.
+
+    Returns:
+        bool: True if the logger is an instance of MLFlowLogger, False otherwise.
+    """
+    if not _fabric_has_logger(fabric):
+        return False
+    return fabric.logger.__class__.__name__ == "MLFlowLogger"
+
+
 class LightningFabricMixin:
     """
     A mixin class for integrating Lightning Fabric into a project.
@@ -78,8 +110,8 @@ class LightningFabricMixin:
        """
        if self._fabric_instance is None:
            if config.get("fabric", None) is None:
-                log.warning("No fabric configuration found. use default settings.")
-                self._fabric_instance = L.Fabric()
+                log.warning("No fabric configuration found. use default settings. By default, use 1 device.")
+                self._fabric_instance = L.Fabric(devices=1)
            else:
                self._fabric_instance = instantiate(config.fabric)
            if not _is_using_cli():  # if not using cli, launch the fabric
@@ -122,7 +154,10 @@ class LightningFabricMixin:
        Retrieves the log directory from the fabric's logger.
        """
        if self.fabric is not None and len(self.fabric._loggers) > 0:
-
+            if hasattr(self.fabric.logger, "log_dir"):
+                log_dir = self.fabric.logger.log_dir
+            else:
+                log_dir = None

            # Special handling for SwanLabLogger to get the correct log directory
            if (
|
|
|
131
166
|
):
|
|
132
167
|
log_dir = self.fabric.logger.save_dir or self.fabric.logger._logdir
|
|
133
168
|
|
|
169
|
+
if (
|
|
170
|
+
log_dir is None
|
|
171
|
+
and self.fabric.logger.__class__.__name__ == "MLFlowLogger"
|
|
172
|
+
):
|
|
173
|
+
log_dir = self.fabric.logger.save_dir
|
|
174
|
+
if log_dir is None:
|
|
175
|
+
try:
|
|
176
|
+
log_dir = self._program.config.path.log_dir
|
|
177
|
+
except Exception:
|
|
178
|
+
log.error(
|
|
179
|
+
"Failed to get log_dir from program config for MLFlowLogger."
|
|
180
|
+
)
|
|
181
|
+
log_dir = "outputs"
|
|
182
|
+
|
|
134
183
|
assert log_dir is not None, "log_dir should not be None"
|
|
135
184
|
if self.fabric.is_global_zero and not os.path.exists(log_dir):
|
|
136
185
|
os.makedirs(log_dir, exist_ok=True)
|
|
@@ -206,14 +255,7 @@ class LightningFabricMixin:
        Returns:
            bool: True if fast_dev_run is enabled, False otherwise.
        """
-
-            return True
-        elif hasattr(self, "_program") and self._program.config.get(
-            "fast_dev_run", False
-        ):
-            return True
-        else:
-            return False
+        return RuntimeConstants().debug

    def log(self, name: str, value: Any, step: Optional[int] = None):
        """
|
|
|
252
294
|
"""
|
|
253
295
|
for i, param_group in enumerate(optimizer.param_groups):
|
|
254
296
|
self.fabric.log(name_template.format(i), param_group["lr"], step=step)
|
|
297
|
+
|
|
298
|
+
def log_artifact(self, local_path: str, artifact_path: str | None = None):
|
|
299
|
+
"""
|
|
300
|
+
Logs a file as an artifact to the fabric's logger.
|
|
301
|
+
|
|
302
|
+
Args:
|
|
303
|
+
local_dir: The path to the directory to log as an artifact.
|
|
304
|
+
artifact_path: The directory within the logger's artifact storage to save the file.
|
|
305
|
+
"""
|
|
306
|
+
if _is_mlflow_logger(self.fabric):
|
|
307
|
+
logger: "MLFlowLogger" = self.fabric.logger
|
|
308
|
+
experiment: "MlflowClient" = logger.experiment
|
|
309
|
+
experiment.log_artifact(
|
|
310
|
+
logger.run_id,
|
|
311
|
+
local_path=local_path,
|
|
312
|
+
artifact_path=(artifact_path),
|
|
313
|
+
)
|
|
314
|
+
|
|
315
|
+
def log_artifacts(self, local_dir: str, artifact_path: str | None = None):
|
|
316
|
+
"""
|
|
317
|
+
Logs a directory as artifacts to the fabric's logger.
|
|
318
|
+
|
|
319
|
+
Args:
|
|
320
|
+
local_dir: The path to the directory to log as artifacts.
|
|
321
|
+
artifact_path: The directory within the logger's artifact storage to save the files.
|
|
322
|
+
"""
|
|
323
|
+
if _is_mlflow_logger(self.fabric):
|
|
324
|
+
logger: "MLFlowLogger" = self.fabric.logger
|
|
325
|
+
experiment: "MlflowClient" = logger.experiment
|
|
326
|
+
experiment.log_artifacts(
|
|
327
|
+
logger.run_id,
|
|
328
|
+
local_dir=local_dir,
|
|
329
|
+
artifact_path=artifact_path,
|
|
330
|
+
)
|
|
331
|
+
|
|
332
|
+
def finalize(self):
|
|
333
|
+
"""
|
|
334
|
+
Destructor to ensure proper cleanup of the Lightning Fabric instance.
|
|
335
|
+
"""
|
|
336
|
+
if self._fabric_instance is None:
|
|
337
|
+
return
|
|
338
|
+
|
|
339
|
+
if _fabric_has_logger(self.fabric) and _is_mlflow_logger(self.fabric):
|
|
340
|
+
if sys.exc_info()[0] is None:
|
|
341
|
+
status = "success"
|
|
342
|
+
else:
|
|
343
|
+
status = "failed"
|
|
344
|
+
self.fabric.logger.finalize(status)
|
|
345
|
+
|
|
346
|
+
del self._fabric_instance
|
|
347
|
+
self._fabric_instance = None
|
|
348
|
+
|
|
349
|
+
def __del__(self):
|
|
350
|
+
"""
|
|
351
|
+
Destructor to ensure proper cleanup of the Lightning Fabric instance.
|
|
352
|
+
"""
|
|
353
|
+
self.finalize()
|
|
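
```python
# A minimal sketch of how a program built on LightningFabricMixin might use the new
# MLflow-aware helpers above. `MyProgram` and the artifact paths are illustrative; the
# helpers are no-ops unless the active fabric logger is an MLFlowLogger (see
# fusion_bench_config/fabric/loggers/mlflow_logger.yaml in the file list).
from fusion_bench.mixins.lightning_fabric import LightningFabricMixin


class MyProgram(LightningFabricMixin):
    def run(self):
        # ... produce outputs under self.log_dir ...
        self.log_artifact("report.json", artifact_path="reports")   # single file
        self.log_artifacts(self.log_dir, artifact_path="outputs")   # whole directory
        self.finalize()  # reports run status to MLflow and releases the Fabric instance
```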

fusion_bench/mixins/openclip_classification.py:
@@ -1,11 +1,165 @@
+import functools
 import logging
+from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Literal, Optional

+import torch
+from omegaconf import DictConfig
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from fusion_bench.dataset.clip_dataset import CLIPDataset
 from fusion_bench.mixins import LightningFabricMixin
-from fusion_bench.
+from fusion_bench.modelpool import OpenCLIPVisionModelPool
+from fusion_bench.models.open_clip import (
+    ClassificationHead,
+    ImageClassifier,
+    ImageEncoder,
+)
+from fusion_bench.utils.data import InfiniteDataLoader

 log = logging.getLogger(__name__)


 class OpenCLIPClassificationMixin(LightningFabricMixin):
+
     _train_processor = None
     _test_processor = None
+    dataloader_kwargs: DictConfig
+    modelpool: OpenCLIPVisionModelPool
+    zero_shot_heads: Dict[str, ClassificationHead] = {}
+
+    def _init_processor(self, encoder: Optional["ImageEncoder"] = None):
+        """
+        Initialize the CLIP processors for training and testing.
+        """
+        if encoder is None:
+            encoder: "ImageEncoder" = self.modelpool.load_pretrained_or_first_model()
+        self._train_processor = encoder.train_preprocess
+        self._test_processor = encoder.val_preprocess
+        return self._train_processor, self._test_processor
+
+    def get_clip_processor(self, stage: Literal["train", "test"]):
+        """
+        Get the CLIP processor, loading it from the model pool if necessary.
+
+        Returns:
+            CLIPProcessor: The CLIP processor for image and text preprocessing.
+
+        Raises:
+            AssertionError: If the model pool is not set.
+        """
+        if stage == "train":
+            if self._train_processor is None:
+                self._init_processor()
+            return self._train_processor
+        elif stage == "test":
+            if self._test_processor is None:
+                self._init_processor()
+            return self._test_processor
+        else:
+            raise ValueError(f"Invalid stage: {stage}")
+
+    def setup_zero_shot_classification_head(
+        self,
+        task_names: Optional[List[str]] = None,
+        freeze: bool = True,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        # check task names consistency across processes
+        _task_names = self.fabric.broadcast(task_names, src=0)
+        if not self.fabric.is_global_zero and task_names != _task_names:
+            raise ValueError("The `task_names` must be the same across all processes.")
+
+        for task in tqdm(
+            self.modelpool.model_names if task_names is None else task_names,
+            "Setting up zero-shot classification head",
+            disable=not self.fabric.is_global_zero,
+        ):
+            head = self.modelpool.load_classification_head(task)
+            if freeze:
+                head.requires_grad_(False)
+            if dtype is not None:
+                head = head.to(dtype=dtype)
+            self.zero_shot_heads[task] = self.to_device(head)
+
+    def set_clip_processor(self, stage: Literal["train", "test"], processor: Callable):
+        """
+        Set the CLIP processor for a specific stage.
+
+        Args:
+            stage (Literal["train", "test"]): The stage for which to set the processor.
+            processor (Callable): The CLIP processor to set.
+        """
+        if stage == "train":
+            self._train_processor = processor
+        elif stage == "test":
+            self._test_processor = processor
+        else:
+            raise ValueError(f"Invalid stage: {stage}")
+
+    @functools.cache
+    def get_shuffled_test_loader_iter(
+        self,
+        task: str,
+        batch_size: Optional[int] = None,
+        num_workers: Optional[int] = None,
+        **loader_kwargs,
+    ) -> Iterator:
+        """
+        Get an iterator for a shuffled test DataLoader.
+
+        This method creates a DataLoader for the test dataset of the specified task,
+        with shuffling enabled. It allows for optional customization of batch size,
+        number of workers, and other DataLoader keyword arguments.
+
+        Args:
+            task (str): The task identifier for which the test dataset is to be loaded.
+            batch_size (Optional[int]): The batch size to use for the DataLoader. If None, the default batch size is used.
+            num_workers (Optional[int]): The number of worker processes to use for data loading. If None, the default number of workers is used.
+            **loader_kwargs: Additional keyword arguments to pass to the DataLoader.
+
+        Returns:
+            Iterator: An iterator over the shuffled test DataLoader.
+        """
+        # get dataloader kwargs
+        dataloader_kwargs = self.dataloader_kwargs.copy()
+        dataloader_kwargs["shuffle"] = True
+        if batch_size is not None:
+            dataloader_kwargs["batch_size"] = batch_size
+        if num_workers is not None:
+            dataloader_kwargs["num_workers"] = num_workers
+        dataloader_kwargs.update(loader_kwargs)
+
+        # get the test dataset
+        clip_dataset = CLIPDataset(
+            self.modelpool.load_test_dataset(task),
+            processor=self.get_clip_processor(stage="test"),
+        )
+        # create the dataloader
+        loader = DataLoader(clip_dataset, **dataloader_kwargs)
+        loader = self.fabric.setup_dataloaders(loader)
+        return iter(InfiniteDataLoader(loader))
+
+    def compute_logits(
+        self,
+        module: ImageClassifier,
+        images,
+        task: str,
+    ):
+        """
+        Compute the logits for a batch of images using the provided module and task.
+
+        Args:
+            module (ImageClassifier): The image classification module to use for computing logits.
+            images (torch.Tensor): The batch of images for which to compute logits.
+            task (str): The task identifier to specify which classification head to use.
+
+        Returns:
+            torch.Tensor: The computed logits for the input images.
+        """
+        if len(self.zero_shot_heads) == 0:
+            self.setup_zero_shot_classification_head()
+        task_head = self.zero_shot_heads[task]
+        features = module(images)
+        logits = task_head(features)
+        return logits
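
```python
# A minimal sketch (not FusionBench code) of how a method class might consume the mixin
# above. `EntropyObjective`, the (images, labels) batch layout, and the surrounding
# modelpool/dataloader_kwargs wiring are assumptions for illustration only.
import torch

from fusion_bench.mixins.openclip_classification import OpenCLIPClassificationMixin


class EntropyObjective(OpenCLIPClassificationMixin):
    def unsupervised_loss(self, classifier, task: str) -> torch.Tensor:
        # infinite, shuffled test iterator provided by the mixin
        images, _labels = next(self.get_shuffled_test_loader_iter(task, batch_size=16))
        # logits are computed with the cached zero-shot head for this task
        logits = self.compute_logits(classifier, images, task)
        probs = torch.softmax(logits, dim=-1)
        # Shannon entropy of the predictive distribution, a common test-time objective
        return -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1).mean()
```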

fusion_bench/mixins/serialization.py:
@@ -68,7 +68,7 @@ def auto_register_config(cls):

    Behavior:
    - **Parameter Registration**: All non-variadic parameters (excluding ``*args``, ``**kwargs``)
-
+      from the __init__ method are automatically added to _config_mapping
    - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
    - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
    - **Default Values**: Applied when parameters are not provided via arguments
|
|
|
7
7
|
from torch import nn
|
|
8
8
|
from torch.utils.data import Dataset
|
|
9
9
|
|
|
10
|
+
from fusion_bench import StateDictType, TorchModelType
|
|
10
11
|
from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
|
|
11
12
|
from fusion_bench.utils import (
|
|
12
13
|
ValidationError,
|
|
13
14
|
instantiate,
|
|
15
|
+
state_dict_sub,
|
|
14
16
|
timeit_context,
|
|
15
17
|
validate_model_name,
|
|
16
18
|
)
|
|
@@ -56,6 +58,10 @@ class BaseModelPool(
|
|
|
56
58
|
**kwargs,
|
|
57
59
|
):
|
|
58
60
|
if isinstance(models, List):
|
|
61
|
+
log.debug(
|
|
62
|
+
"Initializing BaseModelPool with a list of models. "
|
|
63
|
+
"Converting to a dictionary with integer string keys."
|
|
64
|
+
)
|
|
59
65
|
models = {str(model_idx): model for model_idx, model in enumerate(models)}
|
|
60
66
|
|
|
61
67
|
if isinstance(models, dict):
|
|
@@ -80,6 +86,22 @@ class BaseModelPool(
|
|
|
80
86
|
self._test_datasets = test_datasets
|
|
81
87
|
super().__init__(**kwargs)
|
|
82
88
|
|
|
89
|
+
@property
|
|
90
|
+
def has_instance_models(self) -> bool:
|
|
91
|
+
"""
|
|
92
|
+
Check if the model pool contains any pre-instantiated models.
|
|
93
|
+
|
|
94
|
+
Attention:
|
|
95
|
+
Some algorithms may modify the models in-place if they are pre-instantiated.
|
|
96
|
+
|
|
97
|
+
Returns:
|
|
98
|
+
bool: True if there are pre-instantiated models, False otherwise.
|
|
99
|
+
"""
|
|
100
|
+
for model_cfg in self._models.values():
|
|
101
|
+
if isinstance(model_cfg, nn.Module):
|
|
102
|
+
return True
|
|
103
|
+
return False
|
|
104
|
+
|
|
83
105
|
@property
|
|
84
106
|
def has_pretrained(self) -> bool:
|
|
85
107
|
"""
|
|
@@ -328,6 +350,21 @@ class BaseModelPool(
        for model_name in self.model_names:
            yield model_name, self.load_model(model_name)

+    def load_pretrained_model_and_task_vectors(
+        self,
+    ) -> Tuple[TorchModelType, List[StateDictType]]:
+        pretrained_model = self.load_pretrained_model()
+
+        task_vectors = []
+        for model_name in self.model_names:
+            finetuned_model = self.load_model(model_name)
+            task_vector = state_dict_sub(
+                finetuned_model.state_dict(), pretrained_model.state_dict()
+            )
+            task_vectors.append(task_vector)
+
+        return pretrained_model, task_vectors
+
    @property
    def has_train_dataset(self) -> bool:
        """
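
```python
# A minimal sketch of the new BaseModelPool helpers, assuming pre-instantiated modules
# and the conventional "_pretrained_" key for the base model; toy Linear layers stand
# in for real checkpoints.
import torch.nn as nn

from fusion_bench.modelpool import BaseModelPool

modelpool = BaseModelPool(
    models={
        "_pretrained_": nn.Linear(4, 4),
        "task_a": nn.Linear(4, 4),
        "task_b": nn.Linear(4, 4),
    }
)
assert modelpool.has_instance_models  # new property: True because entries are nn.Module

# New helper: returns the pretrained model plus one task vector per model listed in
# `model_names`, each computed as finetuned.state_dict() - pretrained.state_dict().
pretrained_model, task_vectors = modelpool.load_pretrained_model_and_task_vectors()
```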

fusion_bench/modelpool/convnext_for_image_classification.py:
@@ -98,7 +98,7 @@ class ConvNextForImageClassificationPool(BaseModelPool):
    - Load ConvNeXt models either from a pretrained checkpoint or from config.
    - Optionally adapt the classifier head to match dataset classnames.
    - Override `forward` to return logits for consistent interfaces within
-
+      FusionBench.

    See `fusion_bench.modelpool.resnet_for_image_classification` for a closely
    related ResNet-based pool with analogous behavior.
@@ -161,6 +161,9 @@ class ConvNextForImageClassificationPool(BaseModelPool):
        ).logits
        model.original_forward = original_forward

+        # Mark ConvNeXt layers for FusionBench fusion
+        model._fusion_bench_target_modules = ["convnext"]
+
        return model

    @override
@@ -180,7 +183,7 @@ class ConvNextForImageClassificationPool(BaseModelPool):
        - The ConvNeXt model via `model.save_pretrained`.
        - The paired image processor via `AutoImageProcessor.save_pretrained`.
        - If `algorithm_config` is provided and on rank-zero, a README model card
-
+          documenting the FusionBench configuration.
        """
        model.save_pretrained(path)
        self.load_processor().save_pretrained(path)