fusion-bench 0.2.31__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/__main__.py +2 -2
- fusion_bench/dataset/__init__.py +2 -0
- fusion_bench/dataset/clip_dataset.py +4 -72
- fusion_bench/dataset/image_dataset.py +44 -18
- fusion_bench/method/base_algorithm.py +4 -0
- fusion_bench/method/dop/dop.py +0 -22
- fusion_bench/method/dop/dop_general.py +489 -0
- fusion_bench/method/dop/utils.py +24 -4
- fusion_bench/method/emr_merging/__init__.py +1 -0
- fusion_bench/method/emr_merging/emr_merging.py +53 -0
- fusion_bench/method/emr_merging/utils.py +162 -0
- fusion_bench/method/opcm/opcm.py +6 -2
- fusion_bench/method/opcm/opcm_general.py +356 -0
- fusion_bench/method/opcm/utils.py +1 -4
- fusion_bench/method/simple_average.py +52 -18
- fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +108 -3
- fusion_bench/mixins/serialization.py +1 -1
- fusion_bench/modelpool/base_pool.py +37 -1
- fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
- fusion_bench/models/hf_clip.py +20 -0
- fusion_bench/models/modulator/__init__.py +1 -0
- fusion_bench/models/modulator/base.py +123 -0
- fusion_bench/models/parameter_dict.py +119 -29
- fusion_bench/models/utils.py +190 -2
- fusion_bench/models/wrappers/switch.py +90 -0
- fusion_bench/programs/base_program.py +6 -0
- fusion_bench/programs/fabric_fusion_program.py +4 -0
- fusion_bench/scripts/cli.py +19 -8
- fusion_bench/taskpool/image_classification.py +270 -0
- fusion_bench/utils/__init__.py +18 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/dict.py +19 -0
- fusion_bench/utils/dtype.py +19 -0
- fusion_bench/utils/misc.py +1 -0
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/state_dict_arithmetic.py +183 -1
- fusion_bench/utils/tensorboard.py +21 -3
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +51 -37
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
- fusion_bench_config/method/dop/dop_general.yaml +33 -0
- fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
- fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
- fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/taskpool/image_classification.py
ADDED
@@ -0,0 +1,270 @@
+import itertools
+import json
+import os
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from torch import Tensor, nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+from torchmetrics import Accuracy, MeanMetric
+from tqdm.auto import tqdm
+
+from fusion_bench import (
+    BaseTaskPool,
+    LightningFabricMixin,
+    RuntimeConstants,
+    auto_register_config,
+    get_rankzero_logger,
+    instantiate,
+)
+from fusion_bench.dataset import ImageClassificationDataset
+from fusion_bench.models.wrappers.switch import set_active_option
+from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
+from fusion_bench.utils import count_parameters
+
+if TYPE_CHECKING:
+    from transformers import AutoModelForImageClassification
+
+log = get_rankzero_logger(__name__)
+
+
+def _get_logits_from_model_output(outputs) -> Tensor:
+    """Extract logits from model outputs."""
+    match outputs:
+        case Tensor():
+            logits = outputs
+        case dict() | DictConfig() if "logits" in outputs:
+            logits = outputs["logits"]
+            assert isinstance(
+                logits, Tensor
+            ), "The 'logits' in the model output dictionary is not a Tensor."
+        case _:
+            if hasattr(outputs, "logits"):
+                logits = outputs.logits
+                assert isinstance(
+                    logits, Tensor
+                ), "The 'logits' attribute of the model output is not a Tensor."
+            else:
+                raise ValueError(
+                    "Model output is not a Tensor and does not have 'logits' attribute."
+                )
+    return logits
+
+
+@auto_register_config
+class ImageClassificationTaskPool(LightningFabricMixin, BaseTaskPool):
+    _config_mapping = BaseTaskPool._config_mapping | {
+        "_test_datasets": "test_datasets",
+        "_processor": "processor",
+    }
+
+    _processor_instance = None
+    _is_setup: bool = False
+
+    def __init__(
+        self,
+        test_datasets: DictConfig | Dict[str, Any],
+        *,
+        processor: DictConfig | Any,
+        dataloader_kwargs: DictConfig,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # if the processor is given as a transformers processor instance, store it directly
+        if callable(processor):
+            self._processor_instance = processor
+
+    @property
+    def processor(self) -> Any:
+        if self._processor is None:
+            return None
+
+        if self._processor_instance is not None:
+            return self._processor_instance
+
+        match self._processor:
+            case dict() | DictConfig() if "_target_" in self._processor:
+                self._processor_instance = instantiate(self._processor)
+                return self._processor_instance
+            case str():
+                from transformers import AutoProcessor
+
+                self._processor_instance = AutoProcessor.from_pretrained(
+                    self._processor
+                )
+                return self._processor_instance
+
+        raise ValueError("Processor is not properly configured.")
+
+    def setup(self):
+        # Load test datasets
+        test_datasets = {
+            ds_name: ImageClassificationDataset(
+                self.load_test_dataset(ds_name), processor=self.processor
+            )
+            for ds_name in self._test_datasets
+        }
+        self.test_datasets = test_datasets
+        self.test_dataloaders = {
+            ds_name: self.fabric.setup_dataloaders(
+                self.get_dataloader(ds, stage="test")
+            )
+            for ds_name, ds in test_datasets.items()
+        }
+
+    def load_test_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
+        """
+        Load the testing dataset for the specified model.
+
+        Args:
+            dataset_name (str): The name of the model.
+
+        Returns:
+            Dataset: The instantiated testing dataset.
+        """
+        test_dataset = self._test_datasets[dataset_name]
+        if isinstance(test_dataset, (DictConfig, dict)):
+            return instantiate(test_dataset, *args, **kwargs)
+        else:
+            return test_dataset
+
+    def get_dataloader(self, dataset, stage: str):
+        """Create a DataLoader for the specified dataset and training stage.
+
+        Constructs a PyTorch DataLoader with stage-appropriate configurations:
+        - Training stage: shuffling enabled by default
+        - Validation/test stages: shuffling disabled by default
+
+        Args:
+            dataset: The dataset to wrap in a DataLoader.
+            stage (str): Training stage, must be one of "train", "val", or "test".
+                Determines default shuffling behavior.
+
+        Returns:
+            DataLoader: Configured DataLoader for the given dataset and stage.
+        """
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
+
+    @torch.no_grad()
+    def _evaluate(
+        self,
+        classifier,
+        test_loader,
+        num_classes: int,
+        task_name: str = None,
+    ):
+        classifier.eval()
+        accuracy = Accuracy(task="multiclass", num_classes=num_classes)
+        loss_metric = MeanMetric()
+        if RuntimeConstants.debug:
+            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
+            test_loader = itertools.islice(test_loader, 1)
+        else:
+            test_loader = test_loader
+
+        pbar = tqdm(
+            test_loader,
+            desc=f"Evaluating {task_name}" if task_name is not None else "Evaluating",
+            leave=False,
+            dynamic_ncols=True,
+        )
+        for batch in pbar:
+            inputs, targets = batch
+            outputs = classifier(inputs)
+            logits = _get_logits_from_model_output(outputs)
+            if logits.device != targets.device:
+                targets = targets.to(logits.device)
+
+            loss = F.cross_entropy(logits, targets)
+            loss_metric.update(loss.detach().cpu())
+            acc = accuracy(logits.detach().cpu(), targets.detach().cpu())
+            pbar.set_postfix(
+                {
+                    "accuracy": accuracy.compute().item(),
+                    "loss": loss_metric.compute().item(),
+                }
+            )
+
+        acc = accuracy.compute().item()
+        loss = loss_metric.compute().item()
+        results = {"accuracy": acc, "loss": loss}
+        return results
+
+    def evaluate(
+        self,
+        model: Union["AutoModelForImageClassification", nn.Module],
+        name: str = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        assert isinstance(
+            model, nn.Module
+        ), f"Expected model to be an instance of nn.Module, but got {type(model)}"
+
+        if not self._is_setup:
+            self.setup()
+
+        classifier = self.fabric.to_device(model)
+        classifier.eval()
+        report = {}
+        # collect basic model information
+        training_params, all_params = count_parameters(model)
+        report["model_info"] = {
+            "trainable_params": training_params,
+            "all_params": all_params,
+            "trainable_percentage": training_params / all_params,
+        }
+        if name is not None:
+            report["model_info"]["name"] = name
+
+        # evaluate on each task
+        pbar = tqdm(
+            self.test_dataloaders.items(),
+            desc="Evaluating tasks",
+            total=len(self.test_dataloaders),
+        )
+        for task_name, test_dataloader in pbar:
+            set_active_option(classifier, task_name)
+            num_classes = get_num_classes(task_name)
+            result = self._evaluate(
+                classifier,
+                test_dataloader,
+                num_classes=num_classes,
+                task_name=task_name,
+            )
+            report[task_name] = result
+
+        # calculate the average accuracy and loss
+        if "average" not in report:
+            report["average"] = {}
+            accuracies = [
+                value["accuracy"]
+                for key, value in report.items()
+                if "accuracy" in value
+            ]
+            if len(accuracies) > 0:
+                average_accuracy = sum(accuracies) / len(accuracies)
+                report["average"]["accuracy"] = average_accuracy
+            losses = [value["loss"] for key, value in report.items() if "loss" in value]
+            if len(losses) > 0:
+                average_loss = sum(losses) / len(losses)
+                report["average"]["loss"] = average_loss
+
+        log.info(f"Evaluation Result: {report}")
+        if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
+            save_path = os.path.join(self.log_dir, "report.json")
+            for version in itertools.count(1):
+                if not os.path.exists(save_path):
+                    break
+                # if the file already exists, increment the version to avoid overwriting
+                save_path = os.path.join(self.log_dir, f"report_{version}.json")
+            with open(save_path, "w") as fp:
+                json.dump(report, fp)
+            log.info(f"Evaluation report saved to {save_path}")
+        return report
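The private helper `_get_logits_from_model_output` in the new task pool is easy to exercise in isolation; a minimal sketch, assuming the module path above is importable from an installed 0.2.32 wheel:

```python
import torch

# Assumed import path, matching the new file shown above.
from fusion_bench.taskpool.image_classification import _get_logits_from_model_output

# Plain tensors pass straight through.
logits = _get_logits_from_model_output(torch.randn(4, 10))

# Mapping-style outputs (e.g. a dict from a HuggingFace-style model) must expose "logits".
logits = _get_logits_from_model_output({"logits": torch.randn(4, 10)})

# Objects with a .logits attribute are also accepted; anything else raises ValueError.
```

The task pool itself is driven through `evaluate(model)`, which returns a nested report dict with per-task `accuracy`/`loss` entries, a `model_info` block derived from `count_parameters`, and an `average` section, and on rank zero writes the report to `report.json` (or a versioned `report_N.json`) in the log directory.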
fusion_bench/utils/__init__.py
CHANGED
@@ -31,6 +31,11 @@ _import_structure = {
     ],
     "dtype": ["get_dtype", "parse_dtype"],
     "fabric": ["seed_everything_by_time"],
+    "hydra_utils": [
+        "initialize_hydra_config",
+        "get_default_config_path",
+        "get_hydra_output_dir",
+    ],
     "instantiate_utils": [
         "instantiate",
         "is_instantiable",
@@ -40,6 +45,7 @@ _import_structure = {
     "json": ["load_from_json", "save_to_json", "print_json"],
     "lazy_state_dict": ["LazyStateDict"],
     "misc": [
+        "DeprecationWarningMeta",
         "first",
         "has_length",
         "join_lists",
@@ -122,6 +128,11 @@ if TYPE_CHECKING:
     )
     from .dtype import get_dtype, parse_dtype
     from .fabric import seed_everything_by_time
+    from .hydra_utils import (
+        get_default_config_path,
+        get_hydra_output_dir,
+        initialize_hydra_config,
+    )
     from .instantiate_utils import (
         instantiate,
         is_instantiable,
@@ -130,7 +141,13 @@ if TYPE_CHECKING:
     )
     from .json import load_from_json, print_json, save_to_json
     from .lazy_state_dict import LazyStateDict
-    from .misc import
+    from .misc import (
+        DeprecationWarningMeta,
+        first,
+        has_length,
+        join_lists,
+        validate_and_suggest_corrections,
+    )
     from .packages import compare_versions, import_object
     from .parameters import (
         check_parameters_all_equal,
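With the `hydra_utils` group and `DeprecationWarningMeta` registered in `_import_structure` above, the new names should resolve from the top-level `fusion_bench.utils` namespace; a minimal sketch (import only, since the helpers' signatures are not part of this diff, and assuming the usual lazy-module machinery behind `_import_structure`):

```python
# Lazy attribute access: the underlying submodules are only imported when first touched.
from fusion_bench.utils import (
    DeprecationWarningMeta,
    get_default_config_path,
    get_hydra_output_dir,
    initialize_hydra_config,
)
```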
fusion_bench/utils/data.py
CHANGED
@@ -95,7 +95,7 @@ class InfiniteDataLoader:
                 f"Failed to retrieve data from data loader after {self.max_retries} attempts. "
                 f"Last error: [{type(last_exception).__name__}]{last_exception}. "
                 + (
-                    f"The data loader
+                    f"The data loader may be empty."
                     if isinstance(last_exception, StopIteration)
                     else ""
                 )
fusion_bench/utils/dict.py
CHANGED
@@ -41,3 +41,22 @@ def dict_map(f, d: dict, *, max_level: int = -1, skip_levels=0, inplace=False):
 
     dict_map_impl(d, ans, 0)
     return ans
+
+
+def dict_merge(dicts: Iterable[dict], disjoint: bool = True) -> dict:
+    """Merge multiple dictionaries into one.
+
+    Args:
+        dicts (Iterable[dict]): iterable of dictionaries to merge
+        disjoint (bool, optional): if True, raises an error on key conflicts. Defaults to True.
+
+    Returns:
+        dict: merged dictionary
+    """
+    merged_dict = type(dicts[0])()
+    for d in dicts:
+        for k, v in d.items():
+            if disjoint and k in merged_dict:
+                raise ValueError(f"Key conflict when merging dictionaries: {k}")
+            merged_dict[k] = v
+    return merged_dict
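A usage sketch for the new `dict_merge`, importing it from its module path; note that although the annotation says `Iterable[dict]`, the implementation indexes `dicts[0]`, so passing a list (or another sequence) is the safe choice:

```python
from fusion_bench.utils.dict import dict_merge

merged = dict_merge([{"a": 1}, {"b": 2}])                    # {'a': 1, 'b': 2}
relaxed = dict_merge([{"a": 1}, {"a": 2}], disjoint=False)   # later value wins: {'a': 2}

try:
    dict_merge([{"a": 1}, {"a": 2}])  # disjoint=True by default -> ValueError on the duplicate key
except ValueError as err:
    print(err)  # "Key conflict when merging dictionaries: a"
```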
fusion_bench/utils/dtype.py
CHANGED
@@ -146,3 +146,22 @@ def validate_expected_param_dtype(
         raise ValueError(
             f"Parameter {name} has dtype {param.dtype}, but expected {dtype}"
         )
+
+
+def dtype_support_svd(dtype: torch.dtype) -> bool:
+    """
+    Check if the given dtype is supported for SVD operation in PyTorch.
+
+    Args:
+        dtype (torch.dtype): The data type to check.
+
+    Returns:
+        bool: True if the dtype is supported for SVD, False otherwise.
+    """
+    supported_dtypes = {
+        torch.float32,
+        torch.float64,
+        torch.complex64,
+        torch.complex128,
+    }
+    return dtype in supported_dtypes
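A quick check of `dtype_support_svd` against the supported set listed above; the half-precision results follow directly from the set-membership test:

```python
import torch

from fusion_bench.utils.dtype import dtype_support_svd

assert dtype_support_svd(torch.float32)
assert dtype_support_svd(torch.complex128)
assert not dtype_support_svd(torch.float16)   # half precision is not in the supported set
assert not dtype_support_svd(torch.bfloat16)
```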
fusion_bench/utils/misc.py
CHANGED
fusion_bench/utils/packages.py
CHANGED
fusion_bench/utils/state_dict_arithmetic.py
CHANGED
@@ -1,6 +1,16 @@
 from collections import OrderedDict
 from numbers import Number
-from typing import
+from typing import (
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Union,
+    cast,
+)
 
 import torch
 from torch import Tensor
@@ -462,6 +472,118 @@ class ArithmeticStateDict(OrderedDict):
         return cls(result_dict)
 
 
+class LazyStateDictExpr(Mapping[str, torch.Tensor]):
+    """
+    A lazy, key-wise expression over state_dict-like objects.
+    """
+
+    # ---- core Mapping API ----
+    def __getitem__(self, key: str) -> torch.Tensor:
+        raise NotImplementedError
+
+    def __iter__(self) -> Iterator[str]:
+        raise NotImplementedError
+
+    def __len__(self) -> int:
+        raise NotImplementedError
+
+    # ---- arithmetic (build graph only) ----
+    def __add__(self, other):
+        return BinaryOp(torch.add, self, ensure_expr(other))
+
+    def __sub__(self, other):
+        return BinaryOp(torch.sub, self, ensure_expr(other))
+
+    def __mul__(self, scalar):
+        return UnaryOp(lambda x: x * scalar, self)
+
+    def __rmul__(self, scalar):
+        return self.__mul__(scalar)
+
+    def __truediv__(self, scalar):
+        return UnaryOp(lambda x: x / scalar, self)
+
+    # ---- eager escape hatch ----
+    def materialize(
+        self, device=None, dtype=None, non_blocking=False, copy=False
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Eagerly evaluate into an OrderedDict.
+        """
+        out = {}
+        for k in self:
+            v = self[k]
+            out[k] = v.to(
+                device=device,
+                dtype=dtype,
+                non_blocking=non_blocking,
+                copy=copy,
+            )
+        return out
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(lazy)"
+
+
+class StateDictLeaf(LazyStateDictExpr):
+    def __init__(self, state_dict: Mapping[str, torch.Tensor]):
+        self._sd = state_dict
+
+    def __getitem__(self, key: str) -> torch.Tensor:
+        return self._sd[key]
+
+    def __iter__(self):
+        return iter(self._sd)
+
+    def __len__(self):
+        return len(self._sd)
+
+
+class UnaryOp(LazyStateDictExpr):
+    def __init__(self, op: Callable[[torch.Tensor], torch.Tensor], child):
+        self.op = op
+        self.child = child
+
+    def __getitem__(self, key: str):
+        return self.op(self.child[key])
+
+    def __iter__(self):
+        return iter(self.child)
+
+    def __len__(self):
+        return len(self.child)
+
+
+class BinaryOp(LazyStateDictExpr):
+    def __init__(
+        self,
+        op: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        left,
+        right,
+    ):
+        self.op = op
+        self.left = left
+        self.right = right
+
+    def __getitem__(self, key: str):
+        return self.op(self.left[key], self.right[key])
+
+    def __iter__(self):
+        # assume key sets are aligned
+        return iter(self.left)
+
+    def __len__(self):
+        return len(self.left)
+
+
+def ensure_expr(x):
+    if isinstance(x, LazyStateDictExpr):
+        return x
+    if isinstance(x, Mapping):
+        return StateDictLeaf(x)
+    raise TypeError(f"Unsupported operand type: {type(x)}")
+
+
 def _validate_state_dict_list_not_empty(state_dicts: List[StateDictType]) -> None:
     """
     Validate that the list of state dicts is not empty and contains valid state dicts.
@@ -1228,3 +1350,63 @@ def state_dict_hadamard_product(a: StateDictType, b: StateDictType) -> StateDict
     """
     _validate_state_dict_same_keys([a, b])
     return OrderedDict((key, a[key] * b[key]) for key in a)
+
+
+def state_dict_max(
+    state_dicts: List[StateDictType],
+) -> StateDictType:
+    """
+    Compute the element-wise maximum across multiple state dicts.
+
+    Args:
+        state_dicts: List of state dicts to compute the maximum from.
+
+    Returns:
+        A state dict containing the element-wise maximums.
+    """
+    _validate_state_dict_list_not_empty(state_dicts)
+    _validate_state_dict_same_keys(state_dicts)
+
+    max_state_dict = OrderedDict()
+
+    for key in state_dicts[0]:
+        # Initialize with the first tensor
+        max_tensor = state_dicts[0][key].clone()
+
+        # Compute element-wise maximum
+        for state_dict in state_dicts[1:]:
+            max_tensor = torch.max(max_tensor, state_dict[key])
+
+        max_state_dict[key] = max_tensor
+
+    return max_state_dict
+
+
+def state_dict_max_abs(
+    state_dicts: List[StateDictType],
+) -> StateDictType:
+    """
+    Compute the element-wise maximum absolute value across multiple state dicts.
+
+    Args:
+        state_dicts: List of state dicts to compute the maximum absolute values from.
+
+    Returns:
+        A state dict containing the element-wise maximum absolute values.
+    """
+    _validate_state_dict_list_not_empty(state_dicts)
+    _validate_state_dict_same_keys(state_dicts)
+
+    max_abs_state_dict = OrderedDict()
+
+    for key in state_dicts[0]:
+        # Initialize with the absolute values of the first tensor
+        max_abs_tensor = state_dicts[0][key].abs()
+
+        # Compute element-wise maximum absolute value
+        for state_dict in state_dicts[1:]:
+            max_abs_tensor = torch.max(max_abs_tensor, state_dict[key].abs())
+
+        max_abs_state_dict[key] = max_abs_tensor
+
+    return max_abs_state_dict
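The new lazy expression classes and eager element-wise helpers compose as below; a small sketch assuming the names import from `fusion_bench.utils.state_dict_arithmetic` as defined in the hunks above:

```python
import torch

from fusion_bench.utils.state_dict_arithmetic import StateDictLeaf, state_dict_max_abs

a = {"w": torch.tensor([1.0, -3.0]), "b": torch.tensor([0.5, 0.5])}
b = {"w": torch.tensor([-2.0, 2.0]), "b": torch.tensor([1.0, 0.0])}

# Build a lazy expression graph; no tensor arithmetic happens until a key is read.
expr = 0.5 * (StateDictLeaf(a) + StateDictLeaf(b))
averaged = expr.materialize()  # {'w': tensor([-0.5, -0.5]), 'b': tensor([0.75, 0.25])}

# Eager helpers added in the same module operate key-by-key across state dicts.
max_abs = state_dict_max_abs([a, b])  # {'w': tensor([2., 3.]), 'b': tensor([1.0, 0.5])}
```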
fusion_bench/utils/tensorboard.py
CHANGED
@@ -2,14 +2,18 @@
 functions deal with tensorboard logs.
 """
 
-from
+from pathlib import Path
+from typing import Dict, Iterable, List, Union
 
 import numpy as np
 import pandas as pd
 from tensorboard.backend.event_processing import event_accumulator
 
 
-def parse_tensorboard_as_dict(path: str, scalars: Iterable[str]):
+def parse_tensorboard_as_dict(
+    path: Union[str, Path],
+    scalars: Iterable[str],
+) -> Dict[str, pd.DataFrame]:
     """
     returns a dictionary of pandas dataframes for each requested scalar.
 
@@ -20,7 +24,19 @@ def parse_tensorboard_as_dict(path: str, scalars: Iterable[str]):
 
     Returns:
         Dict[str, pandas.DataFrame]: a dictionary of pandas dataframes for each requested scalar
+
+    Example:
+
+        >>> from fusion_bench.utils.tensorboard import parse_tensorboard_as_dict
+        >>> path = "path/to/tensorboard/logs"
+        >>> scalars = ["train/loss", "val/accuracy"]
+        >>> data = parse_tensorboard_as_dict(path, scalars)
+        >>> train_loss_df = data["train/loss"]
+        >>> val_accuracy_df = data["val/accuracy"]
     """
+    if isinstance(path, Path):
+        path = str(path)
+    assert isinstance(path, str), "path must be a string"
     ea = event_accumulator.EventAccumulator(
         path,
         size_guidance={event_accumulator.SCALARS: 0},
@@ -33,7 +49,9 @@ def parse_tensorboard_as_dict(path: str, scalars: Iterable[str]):
     return {k: pd.DataFrame(ea.Scalars(k)) for k in scalars}
 
 
-def parse_tensorboard_as_list(
+def parse_tensorboard_as_list(
+    path: Union[str, Path], scalars: Iterable[str]
+) -> List[pd.DataFrame]:
     """
     returns a list of pandas dataframes for each requested scalar.
 
{fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fusion-bench
-Version: 0.2.31
+Version: 0.2.32
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 Project-URL: Repository, https://github.com/tanganke/fusion_bench
@@ -61,6 +61,8 @@ Dynamic: license-file
 
 FusionBench is a benchmark suite designed to evaluate the performance of various deep model fusion techniques. It aims to provide a comprehensive comparison of different methods on a variety of datasets and tasks.
 
+## :newspaper: News and Related
+
 Projects based on FusionBench and news from the community (descending order of date. If you have any work based on FusionBench, please feel free to let us know, we are willing to add it to the list. :partying_face:):
 
 <details>