fusion-bench 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/__main__.py +2 -2
- fusion_bench/constants/runtime.py +4 -1
- fusion_bench/dataset/__init__.py +2 -0
- fusion_bench/dataset/clip_dataset.py +4 -72
- fusion_bench/dataset/image_dataset.py +44 -18
- fusion_bench/method/base_algorithm.py +4 -0
- fusion_bench/method/classification/image_classification_finetune.py +1 -0
- fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
- fusion_bench/method/dop/dop.py +0 -22
- fusion_bench/method/dop/dop_general.py +489 -0
- fusion_bench/method/dop/utils.py +24 -4
- fusion_bench/method/emr_merging/__init__.py +1 -0
- fusion_bench/method/emr_merging/emr_merging.py +53 -0
- fusion_bench/method/emr_merging/utils.py +162 -0
- fusion_bench/method/opcm/opcm.py +6 -2
- fusion_bench/method/opcm/opcm_general.py +356 -0
- fusion_bench/method/opcm/utils.py +1 -4
- fusion_bench/method/simple_average.py +52 -18
- fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
- fusion_bench/method/task_singular_vector/TSVM.py +7 -6
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
- fusion_bench/mixins/lightning_fabric.py +110 -11
- fusion_bench/mixins/openclip_classification.py +155 -1
- fusion_bench/mixins/serialization.py +1 -1
- fusion_bench/modelpool/base_pool.py +37 -0
- fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
- fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
- fusion_bench/models/hf_clip.py +20 -0
- fusion_bench/models/modulator/__init__.py +1 -0
- fusion_bench/models/modulator/base.py +123 -0
- fusion_bench/models/open_clip/modeling.py +61 -5
- fusion_bench/models/open_clip/utils.py +13 -2
- fusion_bench/models/parameter_dict.py +119 -29
- fusion_bench/models/utils.py +190 -2
- fusion_bench/models/wrappers/switch.py +90 -0
- fusion_bench/programs/base_program.py +6 -0
- fusion_bench/programs/fabric_fusion_program.py +4 -0
- fusion_bench/py.typed +1 -0
- fusion_bench/scripts/cli.py +25 -23
- fusion_bench/scripts/imgui.py +2 -2
- fusion_bench/scripts/webui.py +2 -2
- fusion_bench/taskpool/image_classification.py +270 -0
- fusion_bench/utils/__init__.py +20 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/dict.py +19 -0
- fusion_bench/utils/dtype.py +19 -0
- fusion_bench/utils/hydra_utils.py +75 -0
- fusion_bench/utils/misc.py +1 -0
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/rich_utils.py +42 -19
- fusion_bench/utils/state_dict_arithmetic.py +183 -1
- fusion_bench/utils/tensorboard.py +21 -3
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +70 -53
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
- fusion_bench_config/README.md +9 -0
- fusion_bench_config/fabric/auto.yaml +1 -0
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
- fusion_bench_config/hydra/default.yaml +3 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
- fusion_bench_config/method/dop/dop_general.yaml +33 -0
- fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
- fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
- fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/taskpool/image_classification.py
ADDED
@@ -0,0 +1,270 @@
+import itertools
+import json
+import os
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+from torch import Tensor, nn
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+from torchmetrics import Accuracy, MeanMetric
+from tqdm.auto import tqdm
+
+from fusion_bench import (
+    BaseTaskPool,
+    LightningFabricMixin,
+    RuntimeConstants,
+    auto_register_config,
+    get_rankzero_logger,
+    instantiate,
+)
+from fusion_bench.dataset import ImageClassificationDataset
+from fusion_bench.models.wrappers.switch import set_active_option
+from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
+from fusion_bench.utils import count_parameters
+
+if TYPE_CHECKING:
+    from transformers import AutoModelForImageClassification
+
+log = get_rankzero_logger(__name__)
+
+
+def _get_logits_from_model_output(outputs) -> Tensor:
+    """Extract logits from model outputs."""
+    match outputs:
+        case Tensor():
+            logits = outputs
+        case dict() | DictConfig() if "logits" in outputs:
+            logits = outputs["logits"]
+            assert isinstance(
+                logits, Tensor
+            ), "The 'logits' in the model output dictionary is not a Tensor."
+        case _:
+            if hasattr(outputs, "logits"):
+                logits = outputs.logits
+                assert isinstance(
+                    logits, Tensor
+                ), "The 'logits' attribute of the model output is not a Tensor."
+            else:
+                raise ValueError(
+                    "Model output is not a Tensor and does not have 'logits' attribute."
+                )
+    return logits
+
+
+@auto_register_config
+class ImageClassificationTaskPool(LightningFabricMixin, BaseTaskPool):
+    _config_mapping = BaseTaskPool._config_mapping | {
+        "_test_datasets": "test_datasets",
+        "_processor": "processor",
+    }
+
+    _processor_instance = None
+    _is_setup: bool = False
+
+    def __init__(
+        self,
+        test_datasets: DictConfig | Dict[str, Any],
+        *,
+        processor: DictConfig | Any,
+        dataloader_kwargs: DictConfig,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        # if the processor is given as a transformers processor instance, store it directly
+        if callable(processor):
+            self._processor_instance = processor
+
+    @property
+    def processor(self) -> Any:
+        if self._processor is None:
+            return None
+
+        if self._processor_instance is not None:
+            return self._processor_instance
+
+        match self._processor:
+            case dict() | DictConfig() if "_target_" in self._processor:
+                self._processor_instance = instantiate(self._processor)
+                return self._processor_instance
+            case str():
+                from transformers import AutoProcessor
+
+                self._processor_instance = AutoProcessor.from_pretrained(
+                    self._processor
+                )
+                return self._processor_instance
+
+        raise ValueError("Processor is not properly configured.")
+
+    def setup(self):
+        # Load test datasets
+        test_datasets = {
+            ds_name: ImageClassificationDataset(
+                self.load_test_dataset(ds_name), processor=self.processor
+            )
+            for ds_name in self._test_datasets
+        }
+        self.test_datasets = test_datasets
+        self.test_dataloaders = {
+            ds_name: self.fabric.setup_dataloaders(
+                self.get_dataloader(ds, stage="test")
+            )
+            for ds_name, ds in test_datasets.items()
+        }
+
+    def load_test_dataset(self, dataset_name: str, *args, **kwargs) -> Dataset:
+        """
+        Load the test dataset with the specified name.
+
+        Args:
+            dataset_name (str): The name of the dataset.
+
+        Returns:
+            Dataset: The instantiated testing dataset.
+        """
+        test_dataset = self._test_datasets[dataset_name]
+        if isinstance(test_dataset, (DictConfig, dict)):
+            return instantiate(test_dataset, *args, **kwargs)
+        else:
+            return test_dataset
+
+    def get_dataloader(self, dataset, stage: str):
+        """Create a DataLoader for the specified dataset and training stage.
+
+        Constructs a PyTorch DataLoader with stage-appropriate configurations:
+        - Training stage: shuffling enabled by default
+        - Validation/test stages: shuffling disabled by default
+
+        Args:
+            dataset: The dataset to wrap in a DataLoader.
+            stage (str): Training stage, must be one of "train", "val", or "test".
+                Determines default shuffling behavior.
+
+        Returns:
+            DataLoader: Configured DataLoader for the given dataset and stage.
+        """
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
+
+    @torch.no_grad()
+    def _evaluate(
+        self,
+        classifier,
+        test_loader,
+        num_classes: int,
+        task_name: str = None,
+    ):
+        classifier.eval()
+        accuracy = Accuracy(task="multiclass", num_classes=num_classes)
+        loss_metric = MeanMetric()
+        if RuntimeConstants.debug:
+            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
+            test_loader = itertools.islice(test_loader, 1)
+        else:
+            test_loader = test_loader
+
+        pbar = tqdm(
+            test_loader,
+            desc=f"Evaluating {task_name}" if task_name is not None else "Evaluating",
+            leave=False,
+            dynamic_ncols=True,
+        )
+        for batch in pbar:
+            inputs, targets = batch
+            outputs = classifier(inputs)
+            logits = _get_logits_from_model_output(outputs)
+            if logits.device != targets.device:
+                targets = targets.to(logits.device)
+
+            loss = F.cross_entropy(logits, targets)
+            loss_metric.update(loss.detach().cpu())
+            acc = accuracy(logits.detach().cpu(), targets.detach().cpu())
+            pbar.set_postfix(
+                {
+                    "accuracy": accuracy.compute().item(),
+                    "loss": loss_metric.compute().item(),
+                }
+            )
+
+        acc = accuracy.compute().item()
+        loss = loss_metric.compute().item()
+        results = {"accuracy": acc, "loss": loss}
+        return results
+
+    def evaluate(
+        self,
+        model: Union["AutoModelForImageClassification", nn.Module],
+        name: str = None,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        assert isinstance(
+            model, nn.Module
+        ), f"Expected model to be an instance of nn.Module, but got {type(model)}"
+
+        if not self._is_setup:
+            self.setup()
+
+        classifier = self.fabric.to_device(model)
+        classifier.eval()
+        report = {}
+        # collect basic model information
+        training_params, all_params = count_parameters(model)
+        report["model_info"] = {
+            "trainable_params": training_params,
+            "all_params": all_params,
+            "trainable_percentage": training_params / all_params,
+        }
+        if name is not None:
+            report["model_info"]["name"] = name
+
+        # evaluate on each task
+        pbar = tqdm(
+            self.test_dataloaders.items(),
+            desc="Evaluating tasks",
+            total=len(self.test_dataloaders),
+        )
+        for task_name, test_dataloader in pbar:
+            set_active_option(classifier, task_name)
+            num_classes = get_num_classes(task_name)
+            result = self._evaluate(
+                classifier,
+                test_dataloader,
+                num_classes=num_classes,
+                task_name=task_name,
+            )
+            report[task_name] = result
+
+        # calculate the average accuracy and loss
+        if "average" not in report:
+            report["average"] = {}
+        accuracies = [
+            value["accuracy"]
+            for key, value in report.items()
+            if "accuracy" in value
+        ]
+        if len(accuracies) > 0:
+            average_accuracy = sum(accuracies) / len(accuracies)
+            report["average"]["accuracy"] = average_accuracy
+        losses = [value["loss"] for key, value in report.items() if "loss" in value]
+        if len(losses) > 0:
+            average_loss = sum(losses) / len(losses)
+            report["average"]["loss"] = average_loss
+
+        log.info(f"Evaluation Result: {report}")
+        if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
+            save_path = os.path.join(self.log_dir, "report.json")
+            for version in itertools.count(1):
+                if not os.path.exists(save_path):
+                    break
+                # if the file already exists, increment the version to avoid overwriting
+                save_path = os.path.join(self.log_dir, f"report_{version}.json")
+            with open(save_path, "w") as fp:
+                json.dump(report, fp)
+            log.info(f"Evaluation report saved to {save_path}")
+        return report
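The new task pool is driven by the Hydra configs added in this release (see fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml above). A minimal sketch of programmatic use follows; the config and override names are assumptions based on the file list, and the evaluated model is any nn.Module whose forward returns logits, a dict with a "logits" key, or an object with a .logits attribute:

from fusion_bench.utils import initialize_hydra_config, instantiate

# Compose a config selecting the new task pool (group/option names assumed
# from the YAML paths listed in this release).
cfg = initialize_hydra_config(
    config_name="fabric_model_fusion",
    overrides=["taskpool=ImageClassificationTaskPool/convnext-base-224_8-tasks"],
)
taskpool = instantiate(cfg.taskpool)

# `model` is any logits-producing nn.Module (hypothetical here):
# report = taskpool.evaluate(model, name="merged-convnext")
# print(report["average"]["accuracy"])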
fusion_bench/utils/__init__.py
CHANGED
@@ -31,6 +31,11 @@ _import_structure = {
     ],
     "dtype": ["get_dtype", "parse_dtype"],
     "fabric": ["seed_everything_by_time"],
+    "hydra_utils": [
+        "initialize_hydra_config",
+        "get_default_config_path",
+        "get_hydra_output_dir",
+    ],
     "instantiate_utils": [
         "instantiate",
         "is_instantiable",
@@ -40,6 +45,7 @@ _import_structure = {
     "json": ["load_from_json", "save_to_json", "print_json"],
     "lazy_state_dict": ["LazyStateDict"],
     "misc": [
+        "DeprecationWarningMeta",
         "first",
         "has_length",
         "join_lists",
@@ -53,6 +59,7 @@ _import_structure = {
         "get_parameter_summary",
         "human_readable",
         "print_parameters",
+        "print_trainable_parameters",
         "state_dict_to_vector",
         "trainable_state_dict",
         "vector_to_state_dict",
@@ -121,6 +128,11 @@ if TYPE_CHECKING:
     )
     from .dtype import get_dtype, parse_dtype
     from .fabric import seed_everything_by_time
+    from .hydra_utils import (
+        get_default_config_path,
+        get_hydra_output_dir,
+        initialize_hydra_config,
+    )
     from .instantiate_utils import (
         instantiate,
         is_instantiable,
@@ -129,7 +141,13 @@ if TYPE_CHECKING:
     )
     from .json import load_from_json, print_json, save_to_json
     from .lazy_state_dict import LazyStateDict
-    from .misc import
+    from .misc import (
+        DeprecationWarningMeta,
+        first,
+        has_length,
+        join_lists,
+        validate_and_suggest_corrections,
+    )
     from .packages import compare_versions, import_object
     from .parameters import (
         check_parameters_all_equal,
@@ -138,6 +156,7 @@ if TYPE_CHECKING:
         get_parameter_summary,
         human_readable,
         print_parameters,
+        print_trainable_parameters,
         state_dict_to_vector,
         trainable_state_dict,
         vector_to_state_dict,
fusion_bench/utils/data.py
CHANGED
@@ -95,7 +95,7 @@ class InfiniteDataLoader:
             f"Failed to retrieve data from data loader after {self.max_retries} attempts. "
             f"Last error: [{type(last_exception).__name__}]{last_exception}. "
             + (
-                f"The data loader
+                f"The data loader may be empty."
                 if isinstance(last_exception, StopIteration)
                 else ""
             )
fusion_bench/utils/dict.py
CHANGED
@@ -41,3 +41,22 @@ def dict_map(f, d: dict, *, max_level: int = -1, skip_levels=0, inplace=False):
 
     dict_map_impl(d, ans, 0)
     return ans
+
+
+def dict_merge(dicts: Iterable[dict], disjoint: bool = True) -> dict:
+    """Merge multiple dictionaries into one.
+
+    Args:
+        dicts (Iterable[dict]): iterable of dictionaries to merge
+        disjoint (bool, optional): if True, raises an error on key conflicts. Defaults to True.
+
+    Returns:
+        dict: merged dictionary
+    """
+    merged_dict = type(dicts[0])()
+    for d in dicts:
+        for k, v in d.items():
+            if disjoint and k in merged_dict:
+                raise ValueError(f"Key conflict when merging dictionaries: {k}")
+            merged_dict[k] = v
+    return merged_dict
fusion_bench/utils/dtype.py
CHANGED
@@ -146,3 +146,22 @@ def validate_expected_param_dtype(
         raise ValueError(
             f"Parameter {name} has dtype {param.dtype}, but expected {dtype}"
         )
+
+
+def dtype_support_svd(dtype: torch.dtype) -> bool:
+    """
+    Check if the given dtype is supported for SVD operation in PyTorch.
+
+    Args:
+        dtype (torch.dtype): The data type to check.
+
+    Returns:
+        bool: True if the dtype is supported for SVD, False otherwise.
+    """
+    supported_dtypes = {
+        torch.float32,
+        torch.float64,
+        torch.complex64,
+        torch.complex128,
+    }
+    return dtype in supported_dtypes
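A plausible guard built on this helper: torch.linalg.svd rejects half-precision inputs, so low-precision weights need an upcast before decomposition. This sketch is illustrative and not necessarily how fusion_bench uses the function:

import torch

from fusion_bench.utils.dtype import dtype_support_svd


def safe_svd(matrix: torch.Tensor):
    """Run torch.linalg.svd, upcasting unsupported dtypes to float32 first."""
    if not dtype_support_svd(matrix.dtype):
        matrix = matrix.to(torch.float32)
    return torch.linalg.svd(matrix, full_matrices=False)


U, S, Vh = safe_svd(torch.randn(8, 4, dtype=torch.bfloat16))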
fusion_bench/utils/hydra_utils.py
CHANGED
@@ -1,4 +1,79 @@
+import logging
+import os
+
 import hydra.core.hydra_config
+from hydra import compose, initialize
+from omegaconf import DictConfig
+
+from fusion_bench.constants import PROJECT_ROOT_PATH
+
+log = logging.getLogger(__name__)
+
+
+def get_default_config_path():
+    """
+    Get the default configuration path by searching in common locations.
+    """
+    for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+        for config_dir in ["config", "fusion_bench_config"]:
+            config_path = os.path.join(config_path_root, config_dir)
+            if os.path.exists(config_path) and os.path.isdir(config_path):
+                return os.path.abspath(config_path)
+    return None
+
+
+def initialize_hydra_config(
+    config_name: str,
+    overrides: list[str] = None,
+    config_path: str = None,
+    return_hydra_config: bool = False,
+) -> DictConfig:
+    """
+    Load the Hydra configuration.
+
+    Args:
+        config_name (str): The name of the configuration file (without .yaml extension).
+        overrides (list[str]): A list of configuration overrides.
+        config_path (str): The path to the configuration directory. If None, it will be automatically detected.
+        return_hydra_config (bool): If True, return the Hydra configuration object.
+
+    Returns:
+        DictConfig: The loaded configuration.
+
+    Example:
+        >>> cfg = initialize_hydra_config(
+        ...     config_name="fabric_model_fusion",
+        ...     overrides=["method=dummy", "modelpool=dummy"],
+        ... )
+        >>> print(cfg.method)
+    """
+    if config_path is None:
+        config_path = get_default_config_path()
+
+    # check config_path validity
+    if config_path is None:
+        raise FileNotFoundError("Could not find configuration directory.")
+    if not os.path.isdir(config_path):
+        raise NotADirectoryError(
+            f"Configuration path {config_path} does not exist or is not a directory."
+        )
+
+    if overrides is None:
+        overrides = []
+
+    with initialize(
+        version_base=None,
+        config_path=os.path.relpath(
+            config_path,
+            start=os.path.dirname(__file__),
+        ),
+    ):
+        cfg = compose(
+            config_name=config_name,
+            overrides=overrides,
+            return_hydra_config=return_hydra_config,
+        )
+    return cfg
 
 
 def get_hydra_output_dir():
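The os.path.relpath(..., start=os.path.dirname(__file__)) conversion is needed because Hydra's initialize() interprets config_path relative to the calling file rather than the working directory. Usage, reusing the overrides from the docstring example:

from fusion_bench.utils.hydra_utils import (
    get_default_config_path,
    initialize_hydra_config,
)

# Finds ./config or ./fusion_bench_config under the CWD or the project root,
# or returns None if neither exists:
print(get_default_config_path())

# Compose a config programmatically, without going through the CLI:
cfg = initialize_hydra_config(
    config_name="fabric_model_fusion",
    overrides=["method=dummy", "modelpool=dummy"],
)
print(cfg.method)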
fusion_bench/utils/misc.py
CHANGED
fusion_bench/utils/packages.py
CHANGED
fusion_bench/utils/parameters.py
CHANGED
@@ -10,6 +10,7 @@ from .type import StateDictType
 __all__ = [
     "count_parameters",
     "print_parameters",
+    "print_trainable_parameters",
     "check_parameters_all_equal",
     "get_parameter_statistics",
     "state_dict_to_vector",
@@ -282,6 +283,38 @@ def print_parameters(
     )
 
 
+def print_trainable_parameters(
+    module: nn.Module,
+    is_human_readable: bool = True,
+    print_fn=print,
+    non_zero_only: bool = False,
+):
+    """
+    Print the names and number of trainable parameters in a PyTorch model.
+
+    Args:
+        module (nn.Module): The PyTorch model.
+        is_human_readable (bool, optional): Whether to print the number of parameters in a human-readable format. Defaults to True.
+        print_fn (callable, optional): The function to use for printing. Defaults to print.
+        non_zero_only (bool, optional): Whether to count only non-zero parameters. Defaults to False.
+
+    Prints:
+        The names and number of trainable parameters in the model.
+
+    ```python
+    print_trainable_parameters(model)
+    # weight: 1.50M parameters
+    # bias: 500.00K parameters
+    ```
+    """
+    for name, param in module.named_parameters():
+        if param.requires_grad:
+            num_params = _numel(param, non_zero_only=non_zero_only)
+            if is_human_readable:
+                num_params = human_readable(num_params)
+            print_fn(f"{name}: {num_params} parameters")
+
+
 def check_parameters_all_equal(
     list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]],
 ) -> None:
fusion_bench/utils/rich_utils.py
CHANGED
@@ -93,11 +93,11 @@ def print_bordered(
     Print a message with a colored border.
 
     Args:
-
-
-
-
-
+        message (str): The message to print.
+        title (str, optional): The title of the panel. Defaults to None.
+        style (str, optional): The color style for the border. Defaults to "cyan".
+        code_style (str, optional): The syntax highlighting style if the message is code.
+            Set to None for plain text. Defaults to "python".
     """
     if code_style:
         if format_code:
@@ -168,7 +168,7 @@ def print_config_tree(
         "callbacks",
         "logger",
         "trainer",
-        "
+        "paths",
         "extras",
     ),
     resolve: bool = False,
@@ -179,11 +179,20 @@ def print_config_tree(
 ) -> None:
     """Prints the contents of a DictConfig as a tree structure using the Rich library.
 
-    :
-
-
-
-
+    Args:
+        cfg (DictConfig): A DictConfig composed by Hydra.
+        print_order (Sequence[str], optional): Determines in what order config components are printed.
+            Defaults to ``("data", "model", "callbacks", "logger", "trainer", "paths", "extras")``.
+        resolve (bool, optional): Whether to resolve reference fields of DictConfig.
+            Defaults to ``False``.
+        save_to_file (bool, optional): Whether to export config to the hydra output folder.
+            Defaults to ``False``.
+        theme (str, optional): The theme to use for syntax highlighting. Defaults to "monokai".
+        background_color (str, optional): The background color to use for syntax highlighting.
+            Defaults to "default".
+
+    Returns:
+        None
     """
     style = "tree"
     tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
@@ -200,18 +209,13 @@ def print_config_tree(
         )
     )
 
-    # add all the other fields to queue (not specified in `print_order`)
-    for field in cfg:
-        if field not in queue:
-            queue.append(field)
-
     # generate config tree from queue
     for field in queue:
         branch = tree.add(field, style=style, guide_style=style)
 
         config_group = cfg[field]
         if isinstance(config_group, DictConfig):
-            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
+            branch_content = OmegaConf.to_yaml(config_group, resolve=resolve).strip()
         else:
             branch_content = str(config_group)
 
@@ -224,13 +228,32 @@ def print_config_tree(
         )
     )
 
+    # add all the other fields to queue (not specified in `print_order`)
+    other_fields = [field for field in cfg if field not in queue]
+    if other_fields:
+        others_branch = tree.add(Text("[others]"), style=style, guide_style=style)
+
+        other_cfg = OmegaConf.create({field: cfg[field] for field in other_fields})
+        branch_content = OmegaConf.to_yaml(other_cfg, resolve=resolve).strip()
+
+        others_branch.add(
+            rich.syntax.Syntax(
+                branch_content, "yaml", theme=theme, background_color=background_color
+            )
+        )
+
     # print config tree
     rich.print(tree)
 
     # save config tree to file
     if save_to_file:
-
-
+        if not cfg.get("paths") or not cfg.paths.get("output_dir"):
+            log.error(
+                "Cannot save config tree to file. 'paths.output_dir' is not specified in the config."
+            )
+        else:
+            with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
+                rich.print(tree, file=file)
 
 
 @rank_zero_only