fusion-bench 0.2.28__py3-none-any.whl → 0.2.30__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- fusion_bench/constants/__init__.py +5 -1
- fusion_bench/constants/runtime.py +111 -7
- fusion_bench/dataset/gsm8k.py +6 -2
- fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
- fusion_bench/method/__init__.py +10 -2
- fusion_bench/method/base_algorithm.py +29 -19
- fusion_bench/method/classification/image_classification_finetune.py +1 -2
- fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
- fusion_bench/metrics/model_kinship/__init__.py +2 -0
- fusion_bench/metrics/model_kinship/calculate.py +77 -0
- fusion_bench/metrics/model_kinship/calculate_split.py +171 -0
- fusion_bench/metrics/model_kinship/utility.py +184 -0
- fusion_bench/metrics/nyuv2/__init__.py +31 -0
- fusion_bench/metrics/nyuv2/depth.py +30 -0
- fusion_bench/metrics/nyuv2/loss.py +40 -0
- fusion_bench/metrics/nyuv2/noise.py +24 -0
- fusion_bench/metrics/nyuv2/normal.py +34 -1
- fusion_bench/metrics/nyuv2/segmentation.py +35 -1
- fusion_bench/mixins/clip_classification.py +30 -2
- fusion_bench/mixins/lightning_fabric.py +46 -5
- fusion_bench/mixins/rich_live.py +76 -0
- fusion_bench/modelpool/base_pool.py +86 -5
- fusion_bench/models/masks/mask_model.py +8 -2
- fusion_bench/models/open_clip/modeling.py +7 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +41 -3
- fusion_bench/models/wrappers/task_wise_fusion.py +14 -3
- fusion_bench/scripts/cli.py +14 -0
- fusion_bench/scripts/webui.py +250 -17
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/data.py +100 -9
- fusion_bench/utils/devices.py +3 -1
- fusion_bench/utils/fabric.py +185 -4
- fusion_bench/utils/instantiate_utils.py +29 -18
- fusion_bench/utils/json.py +6 -0
- fusion_bench/utils/misc.py +16 -0
- fusion_bench/utils/rich_utils.py +123 -6
- fusion_bench/utils/validation.py +197 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/METADATA +72 -13
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/RECORD +49 -45
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
- fusion_bench_config/llama_full_finetune.yaml +4 -16
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +4 -13
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench/utils/auto.py +0 -31
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.28.dist-info → fusion_bench-0.2.30.dist-info}/top_level.txt +0 -0
fusion_bench/mixins/clip_classification.py
CHANGED

```diff
@@ -59,6 +59,15 @@ class CLIPClassificationMixin(LightningFabricMixin):
 
     @property
     def clip_processor(self):
+        """
+        Get the CLIP processor, loading it from the model pool if necessary.
+
+        Returns:
+            CLIPProcessor: The CLIP processor for image and text preprocessing.
+
+        Raises:
+            AssertionError: If the model pool is not set.
+        """
         if self._clip_processor is None:
             assert self.modelpool is not None, "Model pool is not set"
             self._clip_processor = self.modelpool.load_processor()
@@ -125,6 +134,11 @@ class CLIPClassificationMixin(LightningFabricMixin):
             clip_model (Optional[CLIPModel]): The CLIP model to use. If not provided, a pretrained model is loaded from the model pool.
             task_names (Optional[List[str]]): A list of task names to set up the classification head for. If not provided, all models in the model pool will be used.
         """
+        # make sure the task names are equal across all processes
+        _task_names = self.fabric.broadcast(task_names, src=0)
+        if not self.fabric.is_global_zero and task_names != _task_names:
+            raise ValueError("The `task_names` must be the same across all processes.")
+
         self.whether_setup_zero_shot_classification_head = True
         # load clip model if not provided
         if clip_model is None:
```
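The broadcast guard added above follows the usual rank-synchronization pattern: rank 0's value is taken as canonical and every other rank compares its local value against it, failing fast instead of silently diverging. A minimal standalone sketch of the pattern (the task list is illustrative; with a single process the guard is a no-op, but under a multi-process launch each rank runs this same code):

```python
import lightning as L

fabric = L.Fabric(accelerator="cpu", devices=1)
fabric.launch()

task_names = ["mnist", "cifar10"]  # this rank's local argument (illustrative)

canonical = fabric.broadcast(task_names, src=0)  # rank 0's object wins
if not fabric.is_global_zero and task_names != canonical:
    raise ValueError("The `task_names` must be the same across all processes.")
```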
```diff
@@ -147,7 +161,10 @@
         self.logit_scale_exp = self.fabric.to_device(self.logit_scale_exp)
 
         @cache_with_joblib()
-        def construct_classification_head(task: str):
+        def construct_classification_head(task: str, model_name: str):
+            log.info(
+                f"Constructing zero-shot classification head for task: {task} using model: {model_name}"
+            )
             nonlocal clip_classifier
 
             classnames, templates = get_classnames_and_templates(task)
@@ -163,7 +180,18 @@
         ):
             zeroshot_weights = None
             if self.fabric.is_global_zero:
-                zeroshot_weights = construct_classification_head(task)
+                if hasattr(clip_model, "config") and hasattr(
+                    clip_model.config, "_name_or_path"
+                ):
+                    model_name = clip_model.config._name_or_path
+                else:
+                    model_name = "unknown_model"
+                    log.warning(
+                        "CLIP model config does not have `_name_or_path` attribute. Using 'unknown_model' as model name."
+                    )
+                zeroshot_weights = construct_classification_head(
+                    task, model_name=model_name
+                )
 
             self.fabric.barrier()
             self.zeroshot_weights[task] = self.fabric.broadcast(zeroshot_weights, src=0)
```
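Passing `model_name` into the `@cache_with_joblib()`-decorated function is what keys the on-disk cache per backbone: joblib-style caches hash the function's arguments, so without the extra argument two different CLIP models would collide on the same cached classification head. A hedged sketch of the mechanism using plain `joblib.Memory` (`cache_with_joblib` is FusionBench's own wrapper; the cache location and return value here are illustrative):

```python
from joblib import Memory

memory = Memory("/tmp/fusion_bench_cache_demo", verbose=0)  # illustrative location

@memory.cache
def construct_classification_head(task: str, model_name: str):
    # The cached result is keyed on BOTH arguments, so heads computed for
    # one backbone are never returned for another.
    return f"zeroshot-weights({task}, {model_name})"

construct_classification_head("mnist", "clip-vit-base-patch32")   # computed
construct_classification_head("mnist", "clip-vit-base-patch32")   # cache hit
construct_classification_head("mnist", "clip-vit-large-patch14")  # recomputed
```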
fusion_bench/mixins/lightning_fabric.py
CHANGED

```diff
@@ -1,7 +1,7 @@
 import functools
 import logging
 import os
-from typing import TYPE_CHECKING, Any, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, List, Mapping, Optional, TypeVar
 
 import lightning as L
 import torch
@@ -96,12 +96,24 @@ class LightningFabricMixin:
 
     @property
     def fabric(self):
+        """
+        Get the Lightning Fabric instance, initializing it if necessary.
+
+        Returns:
+            L.Fabric: The Lightning Fabric instance for distributed computing.
+        """
         if self._fabric_instance is None:
             self.setup_lightning_fabric(getattr(self, "config", DictConfig({})))
         return self._fabric_instance
 
     @fabric.setter
     def fabric(self, instance: L.Fabric):
+        """
+        Set the Lightning Fabric instance.
+
+        Args:
+            instance: The Lightning Fabric instance to use.
+        """
         self._fabric_instance = instance
 
     @property
@@ -172,6 +184,15 @@
     def tensorboard_summarywriter(
         self,
     ) -> "lightning.fabric.loggers.tensorboard.SummaryWriter":
+        """
+        Get the TensorBoard SummaryWriter for detailed logging.
+
+        Returns:
+            SummaryWriter: The TensorBoard SummaryWriter instance.
+
+        Raises:
+            AttributeError: If the logger is not a TensorBoardLogger.
+        """
         if isinstance(self.fabric.logger, TensorBoardLogger):
             return self.fabric.logger.experiment
         else:
@@ -179,6 +200,12 @@
 
     @property
     def is_debug_mode(self):
+        """
+        Check if the program is running in debug mode (fast_dev_run).
+
+        Returns:
+            bool: True if fast_dev_run is enabled, False otherwise.
+        """
         if hasattr(self, "config") and self.config.get("fast_dev_run", False):
             return True
         elif hasattr(self, "_program") and self._program.config.get(
@@ -190,13 +217,22 @@
 
     def log(self, name: str, value: Any, step: Optional[int] = None):
         """
-        Logs
+        Logs a single metric to the fabric's logger.
+
+        Args:
+            name: The name of the metric to log.
+            value: The value of the metric.
+            step: Optional step number for the metric.
         """
         self.fabric.log(name, value, step=step)
 
-    def log_dict(self, metrics:
+    def log_dict(self, metrics: Mapping[str, Any], step: Optional[int] = None):
         """
-        Logs
+        Logs multiple metrics to the fabric's logger.
+
+        Args:
+            metrics: Dictionary of metric names and values.
+            step: Optional step number for the metrics.
         """
         self.fabric.log_dict(metrics, step=step)
 
```
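With the completed signatures, `log` and `log_dict` are thin forwards to `Fabric.log` and `Fabric.log_dict`. A short hedged usage sketch (the subclass, hook name, and metric names are made up; the import path assumes the mixin is re-exported from `fusion_bench.mixins`):

```python
from fusion_bench.mixins import LightningFabricMixin  # assumed re-export path

class MyProgram(LightningFabricMixin):
    def on_step_end(self, step: int, loss: float, acc: float):
        # Log one metric at a time...
        self.log("train/loss", loss, step=step)
        # ...or a whole mapping in a single call.
        self.log_dict({"train/loss": loss, "train/acc": acc}, step=step)
```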
```diff
@@ -207,7 +243,12 @@
         name_template: str = "train/lr_group_{0}",
     ):
         """
-        Logs the learning rate of
+        Logs the learning rate of each parameter group in the optimizer.
+
+        Args:
+            optimizer: The optimizer whose learning rates should be logged.
+            step: Optional step number for the log entry.
+            name_template: Template string for the log name. Use {0} as placeholder for group index.
         """
         for i, param_group in enumerate(optimizer.param_groups):
             self.fabric.log(name_template.format(i), param_group["lr"], step=step)
```
fusion_bench/mixins/rich_live.py
CHANGED
````diff
@@ -2,20 +2,96 @@ from rich.live import Live
 
 
 class RichLiveMixin:
+    """
+    A mixin class that provides Rich Live display capabilities.
+
+    This mixin integrates Rich's Live display functionality, allowing for
+    dynamic, auto-refreshing console output. It's particularly useful for
+    displaying real-time updates, progress information, or continuously
+    changing data without cluttering the terminal.
+
+    Attributes:
+        _rich_live (Live): The internal Rich Live instance for live display updates.
+
+    Example:
+        ```python
+        class MyTask(RichLiveMixin):
+            def run(self):
+                self.start_rich_live()
+                for i in range(100):
+                    self.rich_live_print(f"Processing item {i}")
+                    time.sleep(0.1)
+                self.stop_rich_live()
+        ```
+    """
+
     _rich_live: Live = None
 
     @property
     def rich_live(self) -> Live:
+        """
+        Get the Rich Live instance, creating it if necessary.
+
+        Returns:
+            Live: The Rich Live instance for dynamic console output.
+        """
         if self._rich_live is None:
             self._rich_live = Live()
         return self._rich_live
 
     def start_rich_live(self):
+        """
+        Start the Rich Live display context.
+
+        This method enters the Rich Live context, enabling dynamic console output.
+        Must be paired with stop_rich_live() to properly clean up resources.
+
+        Returns:
+            The Rich Live instance in its started state.
+
+        Example:
+            ```python
+            self.start_rich_live()
+            # Display dynamic content
+            self.rich_live_print("Dynamic output")
+            self.stop_rich_live()
+            ```
+        """
         return self.rich_live.__enter__()
 
     def stop_rich_live(self):
+        """
+        Stop the Rich Live display context and clean up resources.
+
+        This method exits the Rich Live context and resets the internal Live instance.
+        Should be called after start_rich_live() when dynamic display is complete.
+
+        Example:
+            ```python
+            self.start_rich_live()
+            # ... display content ...
+            self.stop_rich_live()
+            ```
+        """
         self.rich_live.__exit__(None, None, None)
         self._rich_live = None
 
     def rich_live_print(self, msg):
+        """
+        Print a message to the Rich Live console.
+
+        This method displays the given message through the Rich Live console,
+        allowing for formatted, dynamic output that updates in place.
+
+        Args:
+            msg: The message to display. Can be a string or any Rich renderable object.
+
+        Example:
+            ```python
+            self.start_rich_live()
+            self.rich_live_print("[bold green]Success![/bold green]")
+            self.rich_live_print(Panel("Status: Running"))
+            self.stop_rich_live()
+            ```
+        """
         self.rich_live.console.print(msg)
````
fusion_bench/modelpool/base_pool.py
CHANGED

```diff
@@ -1,6 +1,6 @@
 import logging
 from copy import deepcopy
-from typing import Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Tuple, Union
 
 import torch
 from omegaconf import DictConfig, OmegaConf, UnsupportedValueType
@@ -8,7 +8,12 @@ from torch import nn
 from torch.utils.data import Dataset
 
 from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
-from fusion_bench.utils import
+from fusion_bench.utils import (
+    ValidationError,
+    instantiate,
+    timeit_context,
+    validate_model_name,
+)
 
 __all__ = ["BaseModelPool"]
 
@@ -59,6 +64,16 @@ class BaseModelPool(
         except UnsupportedValueType:
             pass
 
+        if not models:
+            log.warning("Initialized BaseModelPool with empty models dictionary.")
+        else:
+            # Validate model names
+            for model_name in models.keys():
+                try:
+                    validate_model_name(model_name, allow_special=True)
+                except ValidationError as e:
+                    log.warning(f"Invalid model name '{model_name}': {e}")
+
         self._models = models
         self._train_datasets = train_datasets
         self._val_datasets = val_datasets
@@ -147,7 +162,9 @@
         """
         return model_name.startswith("_") and model_name.endswith("_")
 
-    def get_model_config(
+    def get_model_config(
+        self, model_name: str, return_copy: bool = True
+    ) -> Union[DictConfig, str, Any]:
         """
         Get the configuration for the specified model.
 
@@ -155,10 +172,36 @@
             model_name (str): The name of the model.
 
         Returns:
-            DictConfig: The configuration for the specified model.
+            Union[DictConfig, str, Any]: The configuration for the specified model, which may be a DictConfig, string path, or other type.
+
+        Raises:
+            ValidationError: If model_name is invalid.
+            KeyError: If model_name is not found in the pool.
         """
+        # Validate model name
+        validate_model_name(model_name, allow_special=True)
+
+        # raise friendly error if model not found in the pool
+        if model_name not in self._models:
+            available_models = list(self._models.keys())
+            raise KeyError(
+                f"Model '{model_name}' not found in model pool. "
+                f"Available models: {available_models}"
+            )
+
         model_config = self._models[model_name]
+        if isinstance(model_config, nn.Module):
+            log.warning(
+                f"Model configuration for '{model_name}' is a pre-instantiated model. "
+                "Returning the model instance instead of configuration."
+            )
+
         if return_copy:
+            if isinstance(model_config, nn.Module):
+                # raise performance warning
+                log.warning(
+                    f"Furthermore, returning a copy of the pre-instantiated model '{model_name}' may be inefficient."
+                )
             model_config = deepcopy(model_config)
         return model_config
 
@@ -171,12 +214,28 @@
 
         Returns:
             str: The path for the specified model.
+
+        Raises:
+            ValidationError: If model_name is invalid.
+            KeyError: If model_name is not found in the pool.
+            ValueError: If model configuration is not a string path.
         """
+        # Validate model name
+        validate_model_name(model_name, allow_special=True)
+
+        if model_name not in self._models:
+            available_models = list(self._models.keys())
+            raise KeyError(
+                f"Model '{model_name}' not found in model pool. "
+                f"Available models: {available_models}"
+            )
+
         if isinstance(self._models[model_name], str):
             return self._models[model_name]
         else:
             raise ValueError(
-                "Model
+                f"Model configuration for '{model_name}' is not a string path. "
+                "Try to override this method in derived modelpool class."
             )
 
     def load_model(
@@ -357,3 +416,25 @@
         """
        with timeit_context(f"Saving the state dict of model to {path}"):
            torch.save(model.state_dict(), path)
+
+    def __contains__(self, model_name: str) -> bool:
+        """
+        Check if a model with the given name exists in the model pool.
+
+        Examples:
+            >>> modelpool = BaseModelPool(models={"modelA": ..., "modelB": ...})
+            >>> "modelA" in modelpool
+            True
+            >>> "modelC" in modelpool
+            False
+
+        Args:
+            model_name (str): The name of the model to check.
+
+        Returns:
+            bool: True if the model exists, False otherwise.
+        """
+        if self._models is None:
+            raise RuntimeError("Model pool is not initialized")
+        validate_model_name(model_name, allow_special=True)
+        return model_name in self._models
```
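Together, `__contains__` and the stricter lookups give callers a cheap way to probe the pool before loading. A hedged usage sketch (the `models` keyword and the string config are illustrative; real pools are usually built from Hydra YAML):

```python
from fusion_bench.modelpool.base_pool import BaseModelPool

pool = BaseModelPool(models={"_pretrained_": "openai/clip-vit-base-patch32"})

if "_pretrained_" in pool:                      # uses the new __contains__
    print(pool.get_model_path("_pretrained_"))  # string configs are returned as paths

try:
    pool.get_model_config("missing-model")
except KeyError as err:
    print(err)  # message lists the available model names
```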
fusion_bench/models/masks/mask_model.py
CHANGED

```diff
@@ -113,21 +113,27 @@ class MaskModel(ParameterDictModel):
     def get_distribution(
         self,
         mask_type: Literal["discrete", "continuous"],
+        temperature: float = 0.5,
         **kwargs,
     ):
         return {
-            name: self._param_to_distribution(
+            name: self._param_to_distribution(
+                param, mask_type=mask_type, temperature=temperature, **kwargs
+            )
             for name, param in self.named_parameters()
         }
 
     def sample_mask(
         self,
         mask_type: Literal["discrete", "continuous"] = "discrete",
+        temperature: float = 0.5,
         **kwargs,
     ):
         mask = {}
         for name, param in self.named_parameters():
-            dist = self._param_to_distribution(
+            dist = self._param_to_distribution(
+                param, mask_type, temperature=temperature, **kwargs
+            )
             if mask_type == "discrete":
                 mask[name] = dist.sample()
             elif mask_type == "continuous":
```
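The new `temperature` parameter is the usual relaxation knob for differentiable masks: lower temperatures push samples toward hard 0/1 values, higher ones toward soft values near the underlying probability. A hedged illustration with `torch.distributions.RelaxedBernoulli` (assuming a relaxed Bernoulli is what `_param_to_distribution` builds in the continuous case; that helper is not shown in this diff):

```python
import torch
from torch.distributions import RelaxedBernoulli

logits = torch.zeros(5)  # p = 0.5 for each mask entry

sharp = RelaxedBernoulli(temperature=torch.tensor(0.1), logits=logits).rsample()
soft = RelaxedBernoulli(temperature=torch.tensor(2.0), logits=logits).rsample()

print(sharp)  # values close to 0 or 1
print(soft)   # values spread around 0.5
```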
fusion_bench/models/open_clip/modeling.py
CHANGED

```diff
@@ -1,3 +1,10 @@
+from fusion_bench.utils.packages import is_open_clip_available
+
+if not is_open_clip_available():
+    raise ImportError(
+        "open_clip is not installed. Please install it with `pip install open_clip_torch`."
+    )
+
 from typing import Callable, List
 
 import open_clip
```
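This guard converts an opaque `ModuleNotFoundError` raised deep inside `import open_clip` into an actionable message. The standard way to implement such a check without importing the package is `importlib.util.find_spec`; a hedged sketch of what `is_open_clip_available` plausibly does (the actual helper lives in `fusion_bench.utils.packages` and is not shown in this diff):

```python
import importlib.util

def is_open_clip_available() -> bool:
    # find_spec locates the package on sys.path without executing it,
    # so the check is cheap and side-effect free.
    return importlib.util.find_spec("open_clip") is not None
```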
fusion_bench/models/wrappers/layer_wise_fusion.py
CHANGED

````diff
@@ -173,6 +173,24 @@ class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
 
     @property
     def forward_model(self):
+        """
+        Get a functional model with merged parameters.
+
+        Returns a partial function that applies the pretrained model with the current
+        merged state dictionary. This allows for efficient forward passes without
+        modifying the original model's parameters.
+
+        Returns:
+            Callable: A partial function that can be called with (args, kwargs) to
+                perform forward pass with merged parameters.
+
+        Example:
+            ```python
+            # Internal usage during forward pass
+            forward_fn = merged_model.forward_model
+            output = forward_fn(args=(x,), kwargs={})
+            ```
+        """
         return functools.partial(
             functional_call,
             self.pretrained_model,
@@ -181,10 +199,30 @@
             strict=self.strict,
         )
 
-    def merge_and_unload(
+    def merge_and_unload(
+        self,
+        task_vector_mask: Optional[Dict[str, Tensor]] = None,
+        copy: bool = False,
+    ) -> TorchModelType:
+        """
+        Merge models and return the final merged model.
+
+        Args:
+            task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
+                for selective parameter merging. Defaults to None.
+            copy (bool, optional): Whether to return a deep copy of the pretrained model.
+                Defaults to False. If True, the original pretrained model remains unchanged.
+
+        Returns:
+            TorchModelType: The pretrained model with merged parameters loaded.
+        """
         self.merge_weights(task_vector_mask=task_vector_mask)
-        self.pretrained_model.load_state_dict(self._merged_state_dict)
-        return self.pretrained_model
+        if copy:
+            model = deepcopy(self.pretrained_model)
+        else:
+            model = self.pretrained_model
+        model.load_state_dict(self._merged_state_dict)
+        return model
 
     def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
         """
````
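The `copy` flag controls whether merging overwrites the wrapped pretrained model. Its semantics reduce to the following standalone sketch, where a plain `nn.Linear` stands in for the wrapped model (with the wrapper itself the call is simply `merged = merged_model.merge_and_unload(copy=True)`):

```python
import torch
from copy import deepcopy
from torch import nn

pretrained = nn.Linear(4, 4)
merged_state = {k: torch.zeros_like(v) for k, v in pretrained.state_dict().items()}

# copy=True semantics: a deep copy receives the merged weights and the
# original module keeps its parameters.
standalone = deepcopy(pretrained)
standalone.load_state_dict(merged_state)

# copy=False semantics (the default): the wrapped module itself is
# overwritten, which is cheaper but irreversible.
pretrained.load_state_dict(merged_state)
```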
fusion_bench/models/wrappers/task_wise_fusion.py
CHANGED

```diff
@@ -16,6 +16,7 @@ outputs = merged_model(inputs)
 
 import functools
 import logging
+from copy import deepcopy
 from typing import Any, Callable, Dict, Generic, Iterator, List, Optional  # noqa: F401
 
 import torch
@@ -327,7 +328,11 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
         self._merged_state_dict = state_dict
         return state_dict
 
-    def merge_and_unload(
+    def merge_and_unload(
+        self,
+        task_vector_mask: Optional[Dict[str, Tensor]] = None,
+        copy: bool = False,
+    ) -> TorchModelType:
         """
         Merge models and return the final merged model.
 
@@ -338,6 +343,8 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
         Args:
             task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
                 for selective parameter merging. Defaults to None.
+            copy (bool, optional): Whether to return a deep copy of the pretrained model.
+                Defaults to False. If True, the original pretrained model remains unchanged.
 
         Returns:
             TorchModelType: The pretrained model with merged parameters loaded.
@@ -363,8 +370,12 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
             The original pretrained model parameters will be lost.
         """
         self.merge_weights(task_vector_mask=task_vector_mask)
-        self.pretrained_model.load_state_dict(self._merged_state_dict)
-        return self.pretrained_model
+        if copy:
+            model = deepcopy(self.pretrained_model)
+        else:
+            model = self.pretrained_model
+        model.load_state_dict(self._merged_state_dict)
+        return model
 
     def forward(self, *args, **kwargs):
         """
```
fusion_bench/scripts/cli.py
CHANGED
```diff
@@ -69,6 +69,20 @@ def main(cfg: DictConfig) -> None:
     """
     OmegaConf.resolve(cfg)
     program: BaseHydraProgram = instantiate(cfg)
+
+    # Validate that instantiation succeeded and returned an object with 'run' method
+    if not hasattr(program, "run") or not callable(getattr(program, "run")):
+        err_msg = (
+            f"Expected an object with a callable 'run' method, but got {type(program).__name__}. "
+            "Ensure that the configuration specifies a concrete program class with '_target_'."
+        )
+        if "_target_" not in cfg:
+            err_msg += "\nThe '_target_' field is missing from the root configuration."
+        else:
+            err_msg += f"\nFound '_target_': {cfg._target_}"
+        err_msg += f"\n\nConfiguration content:\n{cfg}"
+        raise TypeError(err_msg)
+
     program.run()
 
 
```
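The check guards against a root config that instantiates into something other than a runnable program, e.g. when `_target_` is missing or points at an unrelated class. A minimal reproduction of the guard using `hydra.utils.instantiate` (FusionBench uses its own `instantiate` wrapper; `collections.Counter` is just a stand-in target that hydra can build but that has no `run()` method, so the check fires):

```python
from omegaconf import OmegaConf
from hydra.utils import instantiate

# A root config whose _target_ names a class without a run() method.
cfg = OmegaConf.create({"_target_": "collections.Counter"})

program = instantiate(cfg)
if not hasattr(program, "run") or not callable(getattr(program, "run", None)):
    raise TypeError(
        f"Expected an object with a callable 'run' method, but got {type(program).__name__}."
    )
```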