fusion-bench 0.2.31__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the versions as they appear in their public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/__main__.py +2 -2
- fusion_bench/dataset/__init__.py +2 -0
- fusion_bench/dataset/clip_dataset.py +4 -72
- fusion_bench/dataset/image_dataset.py +44 -18
- fusion_bench/method/base_algorithm.py +4 -0
- fusion_bench/method/dop/dop.py +0 -22
- fusion_bench/method/dop/dop_general.py +489 -0
- fusion_bench/method/dop/utils.py +24 -4
- fusion_bench/method/emr_merging/__init__.py +1 -0
- fusion_bench/method/emr_merging/emr_merging.py +53 -0
- fusion_bench/method/emr_merging/utils.py +162 -0
- fusion_bench/method/opcm/opcm.py +6 -2
- fusion_bench/method/opcm/opcm_general.py +356 -0
- fusion_bench/method/opcm/utils.py +1 -4
- fusion_bench/method/simple_average.py +52 -18
- fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +108 -3
- fusion_bench/mixins/serialization.py +1 -1
- fusion_bench/modelpool/base_pool.py +37 -1
- fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
- fusion_bench/models/hf_clip.py +20 -0
- fusion_bench/models/modulator/__init__.py +1 -0
- fusion_bench/models/modulator/base.py +123 -0
- fusion_bench/models/parameter_dict.py +119 -29
- fusion_bench/models/utils.py +190 -2
- fusion_bench/models/wrappers/switch.py +90 -0
- fusion_bench/programs/base_program.py +6 -0
- fusion_bench/programs/fabric_fusion_program.py +4 -0
- fusion_bench/scripts/cli.py +19 -8
- fusion_bench/taskpool/image_classification.py +270 -0
- fusion_bench/utils/__init__.py +18 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/dict.py +19 -0
- fusion_bench/utils/dtype.py +19 -0
- fusion_bench/utils/misc.py +1 -0
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/state_dict_arithmetic.py +183 -1
- fusion_bench/utils/tensorboard.py +21 -3
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +51 -37
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
- fusion_bench_config/method/dop/dop_general.yaml +33 -0
- fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
- fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
- fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.31.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py
CHANGED
@@ -86,6 +86,9 @@ _import_structure = {
         "set_print_function_call",
         "set_print_function_call_permeanent",
         "timeit_context",
+        "initialize_hydra_config",
+        "get_default_config_path",
+        "get_hydra_output_dir",
     ],
 }
 
@@ -144,8 +147,11 @@ if TYPE_CHECKING:
         StateDictType,
         TorchModelType,
         cache_with_joblib,
+        get_default_config_path,
+        get_hydra_output_dir,
         get_rankzero_logger,
         import_object,
+        initialize_hydra_config,
         instantiate,
         parse_dtype,
         print_parameters,
fusion_bench/__main__.py
CHANGED
fusion_bench/dataset/__init__.py
CHANGED
@@ -38,10 +38,12 @@ _extra_objects = {
 }
 _import_structure = {
     "clip_dataset": ["CLIPDataset"],
+    "image_dataset": ["ImageClassificationDataset"],
 }
 
 if TYPE_CHECKING:
     from .clip_dataset import CLIPDataset
+    from .image_dataset import ImageClassificationDataset
 
 else:
     sys.modules[__name__] = LazyImporter(
fusion_bench/dataset/clip_dataset.py
CHANGED
@@ -2,80 +2,12 @@
 This module provides a class to convert a dataset whose object is a list of dictionaries with keys "image" and "label" to a dataset whose object is a tuple of tensors (inputs, label) for CLIP models.
 """
 
-from typing import Optional, Tuple
+from fusion_bench.utils import DeprecationWarningMeta
 
-import torch
-from torch.utils.data import Dataset
-from transformers import BaseImageProcessor, CLIPProcessor, ProcessorMixin
+from .image_dataset import ImageClassificationDataset
 
 __all__ = ["CLIPDataset"]
 
 
-class CLIPDataset(Dataset):
-    """
-    A dataset class for CLIP models that converts a dataset of dictionaries or tuples
-    into a format suitable for CLIP processing.
-
-    This class wraps an existing dataset and applies CLIP preprocessing to the images.
-    It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
-    or a tuple/list of (image, label).
-
-    Args:
-        dataset: The original dataset to wrap.
-        processor (CLIPProcessor): The CLIP processor for preparing inputs. If None, no preprocessing is applied and raw images are returned.
-
-    Attributes:
-        dataset: The wrapped dataset.
-        processor (CLIPProcessor): The CLIP processor used for image preprocessing.
-    """
-
-    def __init__(self, dataset: Dataset, processor: Optional[CLIPProcessor] = None):
-        self.dataset = dataset
-        self.processor = processor
-
-    def __len__(self):
-        """Returns the number of items in the dataset."""
-        return len(self.dataset)
-
-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
-        """
-        Retrieves and processes an item from the dataset.
-
-        Args:
-            idx (int): The index of the item to retrieve.
-
-        Returns:
-            tuple: A tuple containing the processed image tensor and the label.
-
-        Raises:
-            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
-        """
-        item = self.dataset[idx]
-        if isinstance(item, dict):
-            item = item
-        elif isinstance(item, (tuple, list)):
-            assert len(item) == 2, "Each item should be a tuple or list of length 2"
-            item = {"image": item[0], "label": item[1]}
-        else:
-            raise ValueError("Each item should be a dictionary or a tuple of length 2")
-        image = item["image"]
-        if self.processor is not None:
-            if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
-                # Apply the processor to the image to get the input tensor
-                image = image.convert("RGB")  # ensure image is in RGB format
-                inputs = self.processor(images=[image], return_tensors="pt")[
-                    "pixel_values"
-                ][0]
-            elif callable(self.processor):
-                inputs = self.processor(image)
-            else:
-                raise ValueError(
-                    "The processor should be a CLIPProcessor or a callable function"
-                )
-        else:
-            # if processor is None, return the raw image directly
-            inputs = image
-        # convert boolean label to int, this is for the case when the label is a binary classification task
-        if isinstance(item["label"], bool):
-            item["label"] = 1 if item["label"] else 0
-        return inputs, item["label"]
+class CLIPDataset(ImageClassificationDataset, metaclass=DeprecationWarningMeta):
+    pass
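`CLIPDataset` is thereby reduced to a deprecated alias of `ImageClassificationDataset`. The implementation of `DeprecationWarningMeta` is not part of this diff; a minimal sketch of how such a metaclass is typically written (an assumption, not the actual `fusion_bench.utils` code):

```python
import warnings


class DeprecationWarningMeta(type):
    """Metaclass that emits a DeprecationWarning when the class is instantiated."""

    def __call__(cls, *args, **kwargs):
        # cls.__mro__[1] is the first base class, i.e. the suggested replacement.
        warnings.warn(
            f"{cls.__name__} is deprecated and may be removed in a future release; "
            f"use {cls.__mro__[1].__name__} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return super().__call__(*args, **kwargs)
```

Existing code that constructs `CLIPDataset(...)` keeps working, but emits a `DeprecationWarning` pointing callers at the replacement class.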
fusion_bench/dataset/image_dataset.py
CHANGED
@@ -1,35 +1,39 @@
-from typing import Any, Callable, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
 
+import torch
 from torch.utils.data import Dataset
+from transformers import BaseImageProcessor, ProcessorMixin
 
 
-class TransformedImageDataset(Dataset):
+class ImageClassificationDataset(Dataset):
     """
-    A dataset class for image classification tasks.
+    A dataset class for image classification models that converts a dataset of dictionaries or tuples
+    into a format suitable for model processing.
 
-    This class wraps an existing dataset and applies a transform to the images.
+    This class wraps an existing dataset and applies preprocessing to the images.
     It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
     or a tuple/list of (image, label).
-
-    Args:
-        dataset: The original dataset to wrap.
-        transform (Callable): A function/transform to apply on the image.
-
-    Attributes:
-        dataset: The wrapped dataset.
-        transform (Callable): The transform to be applied to the images.
     """
 
-    def __init__(
-        self, dataset: Dataset, transform: Callable):
+    def __init__(
+        self,
+        dataset: Dataset,
+        processor: Optional[Union["ProcessorMixin", "BaseImageProcessor"]] = None,
+    ):
+        """
+        Args:
+            dataset (Dataset): The original dataset to wrap.
+            processor (Optional[Union[ProcessorMixin, BaseImageProcessor]]): The processor for preparing inputs.
+                If None, no preprocessing is applied and raw images are returned.
+        """
         self.dataset = dataset
-        self.transform = transform
+        self.processor = processor
 
     def __len__(self):
         """Returns the number of items in the dataset."""
         return len(self.dataset)
 
-    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
         """
         Retrieves and processes an item from the dataset.
 
@@ -37,11 +41,13 @@ class TransformedImageDataset(Dataset):
             idx (int): The index of the item to retrieve.
 
         Returns:
-            tuple: A tuple containing the processed image and the label.
+            tuple: A tuple containing the processed image tensor and the label.
 
         Raises:
             ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
         """
+        # Standardize the item to a dictionary format
+        # {"image": ..., "label": ...}
         item = self.dataset[idx]
         if isinstance(item, dict):
             item = item
@@ -50,6 +56,26 @@ class TransformedImageDataset(Dataset):
             item = {"image": item[0], "label": item[1]}
         else:
             raise ValueError("Each item should be a dictionary or a tuple of length 2")
+
+        # Process the image using the provided processor, if any
         image = item["image"]
-        inputs = self.transform(image)
+        if self.processor is not None:
+            if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
+                # Apply the processor to the image to get the input tensor
+                image = image.convert("RGB")  # ensure image is in RGB format
+                inputs = self.processor(images=[image], return_tensors="pt")[
+                    "pixel_values"
+                ][0]
+            elif callable(self.processor):
+                inputs = self.processor(image)
+            else:
+                raise ValueError(
+                    "The processor should be a transformers Processor or a callable function"
+                )
+        else:
+            # if processor is None, return the raw image directly
+            inputs = image
+        # convert boolean label to int, this is for the case when the label is a binary classification task
+        if isinstance(item["label"], bool):
+            item["label"] = 1 if item["label"] else 0
         return inputs, item["label"]
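In use, the wrapper sits between any image/label dataset and a `DataLoader`. A usage sketch, assuming a Hugging Face dataset whose items are `{"image": ..., "label": ...}` dicts (the dataset and checkpoint names are illustrative):

```python
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import CLIPImageProcessor

from fusion_bench.dataset import ImageClassificationDataset

# Any source whose items are {"image": ..., "label": ...} dicts
# or (image, label) tuples is accepted by the wrapper.
raw = load_dataset("food101", split="validation")
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

dataset = ImageClassificationDataset(raw, processor=processor)
pixel_values, label = dataset[0]  # (3, 224, 224) float tensor, int label
loader = DataLoader(dataset, batch_size=32, num_workers=4)
```

Passing `processor=None` returns the raw images, and a plain callable (e.g. a torchvision transform) is applied directly, which covers the behavior of the removed `transform` argument.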
fusion_bench/method/base_algorithm.py
CHANGED
@@ -59,6 +59,10 @@ class BaseAlgorithm(BaseYAMLSerializable):
     core fusion logic in the `run` method, while optional lifecycle hooks allow for
     setup and cleanup operations.
 
+    If model has `_fusion_bench_target_modules` attribute, the algorithm will only fuse
+    the specified target modules. This is useful for models where only certain layers
+    should be fused (e.g., classification heads on top of a shared backbone are not merged).
+
     Attributes:
         _program: Optional program reference for algorithm execution context.
         _config_key (str): Configuration key used for YAML serialization, defaults to "method".
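The diff adds only the docstring contract; how an algorithm consumes `_fusion_bench_target_modules` is not shown in this hunk. A hypothetical helper illustrating the intended filtering (the function name is illustrative, not fusion-bench API):

```python
from typing import Dict

import torch
from torch import nn


def filter_to_target_modules(
    model: nn.Module, state_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Restrict a state dict to the model's declared fusion targets.

    If the model declares no `_fusion_bench_target_modules`, the full
    state dict is returned and the whole model is eligible for fusion.
    """
    targets = getattr(model, "_fusion_bench_target_modules", None)
    if not targets:
        return state_dict
    # Keep parameters whose names fall under one of the target module prefixes.
    return {
        name: tensor
        for name, tensor in state_dict.items()
        if any(name == t or name.startswith(t + ".") for t in targets)
    }
```

A model could set, say, `_fusion_bench_target_modules = ["backbone"]` so that merging touches the shared encoder while each task-specific classification head is left intact.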
fusion_bench/method/dop/dop.py
CHANGED
@@ -79,28 +79,6 @@ class ContinualDOPForCLIP(BaseAlgorithm, LightningFabricMixin):
         ), "The alpha should be in the range of [0, 1]"
         super().__init__(**kwargs)
 
-    def print_params(self, pretrained_model):
-        total_params = 0
-        linear_params = 0
-        linear_weight_params = 0
-        for module_name, module in pretrained_model.named_modules():
-            if not is_leaf_module(module):
-                continue
-            if isinstance(module, nn.Linear):
-                linear_params += sum(p.numel() for n, p in module.named_parameters())
-                linear_weight_params += sum(
-                    p.numel() for n, p in module.named_parameters() if "weight" in n
-                )
-            total_params += sum(p.numel() for p in module.parameters())
-
-        linear_ratio = linear_params / total_params * 100
-        linear_weight_ratio = linear_weight_params / total_params * 100
-        print(f"Total Parameters: {total_params}")
-        print(f"Linear Parameters: {linear_params}")
-        print(f"Linear Weight Parameters: {linear_weight_params}")
-        print(f"Linear Ratio: {linear_ratio:.2f}%")
-        print(f"Linear Weight Ratio: {linear_weight_ratio:.2f}%")
-
     def run(self, modelpool: BaseModelPool):
         if self.seed is not None:
             L.seed_everything(self.seed)