fusion-bench 0.2.30__py3-none-any.whl → 0.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +6 -0
- fusion_bench/__main__.py +2 -2
- fusion_bench/constants/runtime.py +4 -1
- fusion_bench/dataset/__init__.py +2 -0
- fusion_bench/dataset/clip_dataset.py +4 -72
- fusion_bench/dataset/image_dataset.py +44 -18
- fusion_bench/method/base_algorithm.py +4 -0
- fusion_bench/method/classification/image_classification_finetune.py +1 -0
- fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py +285 -0
- fusion_bench/method/dop/dop.py +0 -22
- fusion_bench/method/dop/dop_general.py +489 -0
- fusion_bench/method/dop/utils.py +24 -4
- fusion_bench/method/emr_merging/__init__.py +1 -0
- fusion_bench/method/emr_merging/emr_merging.py +53 -0
- fusion_bench/method/emr_merging/utils.py +162 -0
- fusion_bench/method/opcm/opcm.py +6 -2
- fusion_bench/method/opcm/opcm_general.py +356 -0
- fusion_bench/method/opcm/utils.py +1 -4
- fusion_bench/method/simple_average.py +52 -18
- fusion_bench/method/task_arithmetic/task_arithmetic.py +1 -1
- fusion_bench/method/task_singular_vector/TSVM.py +7 -6
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +0 -1
- fusion_bench/mixins/lightning_fabric.py +110 -11
- fusion_bench/mixins/openclip_classification.py +155 -1
- fusion_bench/mixins/serialization.py +1 -1
- fusion_bench/modelpool/base_pool.py +37 -0
- fusion_bench/modelpool/convnext_for_image_classification.py +5 -2
- fusion_bench/modelpool/openclip_vision/modelpool.py +12 -3
- fusion_bench/models/hf_clip.py +20 -0
- fusion_bench/models/modulator/__init__.py +1 -0
- fusion_bench/models/modulator/base.py +123 -0
- fusion_bench/models/open_clip/modeling.py +61 -5
- fusion_bench/models/open_clip/utils.py +13 -2
- fusion_bench/models/parameter_dict.py +119 -29
- fusion_bench/models/utils.py +190 -2
- fusion_bench/models/wrappers/switch.py +90 -0
- fusion_bench/programs/base_program.py +6 -0
- fusion_bench/programs/fabric_fusion_program.py +4 -0
- fusion_bench/py.typed +1 -0
- fusion_bench/scripts/cli.py +25 -23
- fusion_bench/scripts/imgui.py +2 -2
- fusion_bench/scripts/webui.py +2 -2
- fusion_bench/taskpool/image_classification.py +270 -0
- fusion_bench/utils/__init__.py +20 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/dict.py +19 -0
- fusion_bench/utils/dtype.py +19 -0
- fusion_bench/utils/hydra_utils.py +75 -0
- fusion_bench/utils/misc.py +1 -0
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/rich_utils.py +42 -19
- fusion_bench/utils/state_dict_arithmetic.py +183 -1
- fusion_bench/utils/tensorboard.py +21 -3
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/METADATA +3 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/RECORD +70 -53
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/WHEEL +1 -1
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/entry_points.txt +1 -1
- fusion_bench_config/README.md +9 -0
- fusion_bench_config/fabric/auto.yaml +1 -0
- fusion_bench_config/fabric/loggers/mlflow_logger.yaml +4 -0
- fusion_bench_config/hydra/default.yaml +3 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_tsvm.yaml +38 -0
- fusion_bench_config/method/dop/dop_general.yaml +33 -0
- fusion_bench_config/method/emr_merging/emr_merging.yaml +1 -0
- fusion_bench_config/method/opcm/opcm_general.yaml +18 -0
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224_8-tasks.yaml +15 -0
- fusion_bench_config/taskpool/ImageClassificationTaskPool/convnext-base-224_8-tasks.yaml +17 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.30.dist-info → fusion_bench-0.2.32.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py
CHANGED
@@ -86,6 +86,9 @@ _import_structure = {
         "set_print_function_call",
         "set_print_function_call_permeanent",
         "timeit_context",
+        "initialize_hydra_config",
+        "get_default_config_path",
+        "get_hydra_output_dir",
     ],
 }
 
@@ -144,8 +147,11 @@ if TYPE_CHECKING:
     StateDictType,
     TorchModelType,
    cache_with_joblib,
+    get_default_config_path,
+    get_hydra_output_dir,
     get_rankzero_logger,
     import_object,
+    initialize_hydra_config,
     instantiate,
     parse_dtype,
     print_parameters,
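For context, `fusion_bench/__init__.py` uses the lazy-import pattern popularized by `transformers`: public names are declared in `_import_structure` and resolved only on first attribute access, while the `TYPE_CHECKING` branch gives static type checkers real imports. A minimal sketch of the idea; this `LazyImporter` is a simplified stand-in, not fusion_bench's actual implementation:

```python
import importlib
import sys
from types import ModuleType


class LazyImporter(ModuleType):
    """Simplified stand-in: resolves exported names from submodules lazily."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        self._import_structure = import_structure
        # map exported name -> submodule it lives in
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        submodule = importlib.import_module(
            f"{self.__name__}.{self._name_to_module[attr]}"
        )
        value = getattr(submodule, attr)
        setattr(self, attr, value)  # cache so later lookups skip __getattr__
        return value
```

In the package itself, `sys.modules[__name__]` is replaced by such an importer so that adding entries like `"initialize_hydra_config"` to `_import_structure` exposes them at the top level without importing heavy dependencies eagerly.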
fusion_bench/constants/runtime.py
CHANGED
@@ -89,7 +89,10 @@ class RuntimeConstants:
         self._initialized = True
 
     debug = False
-    """
+    """
+    Global debug flag for enabling verbose logging and debugging features.
+    Use `RuntimeConstants().debug` instead of `RuntimeConstants.debug`
+    """
 
     @property
     def cache_dir(self) -> Path:
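The new docstring points at the class's singleton design (note the `self._initialized = True` guard in the surrounding context): reading `RuntimeConstants.debug` off the class would miss any value set on the shared instance, hence the advice to go through `RuntimeConstants()`. A minimal sketch of that pattern, assuming an `__new__`-based singleton; the packaged class may differ:

```python
class RuntimeConstants:
    """Process-wide runtime constants (minimal singleton sketch)."""

    _instance = None

    debug = False
    """
    Global debug flag for enabling verbose logging and debugging features.
    Use `RuntimeConstants().debug` instead of `RuntimeConstants.debug`.
    """

    def __new__(cls):
        # always hand back the same shared instance
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if getattr(self, "_initialized", False):
            return
        self._initialized = True


# Both calls return the same object, so instance-level state is shared:
assert RuntimeConstants() is RuntimeConstants()
```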
fusion_bench/dataset/__init__.py
CHANGED
@@ -38,10 +38,12 @@ _extra_objects = {
 }
 _import_structure = {
     "clip_dataset": ["CLIPDataset"],
+    "image_dataset": ["ImageClassificationDataset"],
 }
 
 if TYPE_CHECKING:
     from .clip_dataset import CLIPDataset
+    from .image_dataset import ImageClassificationDataset
 
 else:
     sys.modules[__name__] = LazyImporter(
fusion_bench/dataset/clip_dataset.py
CHANGED

@@ -2,80 +2,12 @@
 This module provides a class to convert a dataset whose object is a list of dictionaries with keys "image" and "label" to a dataset whose object is a tuple of tensors (inputs, label) for CLIP models.
 """
 
-from typing import Optional, Tuple
+from fusion_bench.utils import DeprecationWarningMeta
 
-import torch
-from torch.utils.data import Dataset
-from transformers import BaseImageProcessor, CLIPProcessor, ProcessorMixin
+from .image_dataset import ImageClassificationDataset
 
 __all__ = ["CLIPDataset"]
 
 
-class CLIPDataset(Dataset):
-    """
-    A dataset class for CLIP models that converts a dataset of dictionaries or tuples
-    into a format suitable for CLIP processing.
-
-    This class wraps an existing dataset and applies CLIP preprocessing to the images.
-    It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
-    or a tuple/list of (image, label).
-
-    Args:
-        dataset: The original dataset to wrap.
-        processor (CLIPProcessor): The CLIP processor for preparing inputs. If None, no preprocessing is applied and raw images are returned.
-
-    Attributes:
-        dataset: The wrapped dataset.
-        processor (CLIPProcessor): The CLIP processor used for image preprocessing.
-    """
-
-    def __init__(self, dataset: Dataset, processor: Optional[CLIPProcessor] = None):
-        self.dataset = dataset
-        self.processor = processor
-
-    def __len__(self):
-        """Returns the number of items in the dataset."""
-        return len(self.dataset)
-
-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
-        """
-        Retrieves and processes an item from the dataset.
-
-        Args:
-            idx (int): The index of the item to retrieve.
-
-        Returns:
-            tuple: A tuple containing the processed image tensor and the label.
-
-        Raises:
-            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
-        """
-        item = self.dataset[idx]
-        if isinstance(item, dict):
-            item = item
-        elif isinstance(item, (tuple, list)):
-            assert len(item) == 2, "Each item should be a tuple or list of length 2"
-            item = {"image": item[0], "label": item[1]}
-        else:
-            raise ValueError("Each item should be a dictionary or a tuple of length 2")
-        image = item["image"]
-        if self.processor is not None:
-            if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
-                # Apply the processor to the image to get the input tensor
-                image = image.convert("RGB")  # ensure image is in RGB format
-                inputs = self.processor(images=[image], return_tensors="pt")[
-                    "pixel_values"
-                ][0]
-            elif callable(self.processor):
-                inputs = self.processor(image)
-            else:
-                raise ValueError(
-                    "The processor should be a CLIPProcessor or a callable function"
-                )
-        else:
-            # if processor is None, return the raw image directly
-            inputs = image
-        # convert boolean label to int, this is for the case when the label is a binary classification task
-        if isinstance(item["label"], bool):
-            item["label"] = 1 if item["label"] else 0
-        return inputs, item["label"]
+class CLIPDataset(ImageClassificationDataset, metaclass=DeprecationWarningMeta):
+    pass
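`CLIPDataset` thus survives only as an alias of `ImageClassificationDataset`, with `DeprecationWarningMeta` presumably emitting a warning when the old name is used. A minimal sketch of how such a metaclass can be written; the actual `fusion_bench.utils.DeprecationWarningMeta` may differ:

```python
import warnings


class DeprecationWarningMeta(type):
    """Minimal sketch: warn whenever a class using this metaclass is instantiated."""

    def __call__(cls, *args, **kwargs):
        warnings.warn(
            f"{cls.__name__} is deprecated; use its base class "
            f"{cls.__mro__[1].__name__} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return super().__call__(*args, **kwargs)
```

With this arrangement, existing `CLIPDataset(...)` call sites keep working while nudging callers toward the new class.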
fusion_bench/dataset/image_dataset.py
CHANGED

@@ -1,35 +1,39 @@
-from typing import Any, Callable, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
 
+import torch
 from torch.utils.data import Dataset
+from transformers import BaseImageProcessor, ProcessorMixin
 
 
-class TransformedImageDataset(Dataset):
+class ImageClassificationDataset(Dataset):
     """
-    A dataset class for image classification …
+    A dataset class for image classification models that converts a dataset of dictionaries or tuples
+    into a format suitable for model processing.
 
-    This class wraps an existing dataset and applies …
+    This class wraps an existing dataset and applies preprocessing to the images.
     It expects each item in the dataset to be either a dictionary with 'image' and 'label' keys,
     or a tuple/list of (image, label).
-
-    Args:
-        dataset: The original dataset to wrap.
-        transform (Callable): A function/transform to apply on the image.
-
-    Attributes:
-        dataset: The wrapped dataset.
-        transform (Callable): The transform to be applied to the images.
     """
 
-    def __init__(…
-        …
+    def __init__(
+        self,
+        dataset: Dataset,
+        processor: Optional[Union["ProcessorMixin", "BaseImageProcessor"]] = None,
+    ):
+        """
+        Args:
+            dataset (Dataset): The original dataset to wrap.
+            processor (Optional[Union[ProcessorMixin, BaseImageProcessor]]): The processor for preparing inputs.
+                If None, no preprocessing is applied and raw images are returned.
+        """
         self.dataset = dataset
-        self.transform = transform
+        self.processor = processor
 
     def __len__(self):
         """Returns the number of items in the dataset."""
         return len(self.dataset)
 
-    def __getitem__(self, idx: int) -> Tuple[…]:
+    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
         """
         Retrieves and processes an item from the dataset.
 
@@ -37,11 +41,13 @@ class TransformedImageDataset(Dataset):
             idx (int): The index of the item to retrieve.
 
         Returns:
-            tuple: A tuple containing the processed image and the label.
+            tuple: A tuple containing the processed image tensor and the label.
 
         Raises:
            ValueError: If the item is neither a dictionary nor a tuple/list of length 2.
         """
+        # Standardize the item to a dictionary format
+        # {"image": ..., "label": ...}
         item = self.dataset[idx]
         if isinstance(item, dict):
             item = item
@@ -50,6 +56,26 @@ class TransformedImageDataset(Dataset):
             item = {"image": item[0], "label": item[1]}
         else:
             raise ValueError("Each item should be a dictionary or a tuple of length 2")
+
+        # Process the image using the provided processor, if any
         image = item["image"]
-        inputs = self.transform(image)
+        if self.processor is not None:
+            if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
+                # Apply the processor to the image to get the input tensor
+                image = image.convert("RGB")  # ensure image is in RGB format
+                inputs = self.processor(images=[image], return_tensors="pt")[
+                    "pixel_values"
+                ][0]
+            elif callable(self.processor):
+                inputs = self.processor(image)
+            else:
+                raise ValueError(
+                    "The processor should be a transformers Processor or a callable function"
+                )
+        else:
+            # if processor is None, return the raw image directly
+            inputs = image
+        # convert boolean label to int, this is for the case when the label is a binary classification task
+        if isinstance(item["label"], bool):
+            item["label"] = 1 if item["label"] else 0
         return inputs, item["label"]
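As the rewritten `__getitem__` shows, the wrapper accepts either a `transformers` image processor (anything deriving from `ProcessorMixin` or `BaseImageProcessor`) or a plain callable such as a torchvision transform. A hypothetical usage sketch; the MNIST split and CLIP processor below are placeholders, not something this diff prescribes:

```python
from datasets import load_dataset
from transformers import CLIPImageProcessor

from fusion_bench.dataset import ImageClassificationDataset

# Wrap any dataset whose items look like {"image": PIL.Image, "label": int}.
raw = load_dataset("mnist", split="test")  # placeholder dataset
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

dataset = ImageClassificationDataset(raw, processor=processor)
pixel_values, label = dataset[0]  # pixel_values: (3, 224, 224) float tensor
```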
fusion_bench/method/base_algorithm.py
CHANGED

@@ -59,6 +59,10 @@ class BaseAlgorithm(BaseYAMLSerializable):
     core fusion logic in the `run` method, while optional lifecycle hooks allow for
     setup and cleanup operations.
 
+    If model has `_fusion_bench_target_modules` attribute, the algorithm will only fuse
+    the specified target modules. This is useful for models where only certain layers
+    should be fused (e.g., classification heads on top of a shared backbone are not merged).
+
     Attributes:
         _program: Optional program reference for algorithm execution context.
         _config_key (str): Configuration key used for YAML serialization, defaults to "method".
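The paragraph added to the docstring describes an opt-in convention rather than an API call: a model advertises which of its submodules are safe to merge. A hypothetical sketch of a model using the convention; the attribute name comes from the docstring above, while how each algorithm consumes it is not shown in this diff:

```python
import torch
import torch.nn as nn


class ClassifierWithBackbone(nn.Module):
    """Hypothetical model that opts into partial fusion."""

    def __init__(self, backbone: nn.Module, feat_dim: int, num_classes: int):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(feat_dim, num_classes)
        # Hypothetical usage: ask fusion algorithms to merge only the shared
        # backbone, leaving the task-specific head untouched.
        self._fusion_bench_target_modules = ["backbone"]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x))
```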
fusion_bench/method/concrete_subspace/clip_concrete_tsvm.py
ADDED

@@ -0,0 +1,285 @@
+import logging
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Dict, Iterable, List, Literal, Optional
+
+import torch
+from omegaconf import DictConfig
+from tqdm import tqdm
+
+from fusion_bench import (
+    BaseAlgorithm,
+    OpenCLIPClassificationMixin,
+    OpenCLIPVisionModelPool,
+    SimpleProfilerMixin,
+    StateDictType,
+    auto_register_config,
+    get_rankzero_logger,
+    instantiate,
+)
+from fusion_bench.method.adamerging.entropy_loss import entropy_loss
+from fusion_bench.method.task_singular_vector import TaskSingularVectorMerging
+from fusion_bench.method.task_singular_vector.utils import (
+    TSVM_utils,
+    check_parameterNamesMatch,
+    check_state_dicts_equal,
+    state_dict_to_vector,
+    vector_to_state_dict,
+)
+from fusion_bench.models.masks import MaskModel, mask_sparsity
+from fusion_bench.models.open_clip import (
+    ClassificationHead,
+    ImageClassifier,
+    ImageEncoder,
+)
+from fusion_bench.models.wrappers.task_wise_fusion import (
+    TaskWiseMergedModel,
+    get_task_wise_weights,
+)
+from fusion_bench.utils.devices import clear_cuda_cache
+from fusion_bench.utils.dtype import parse_dtype
+from fusion_bench.utils.parameters import print_parameters, print_trainable_parameters
+from fusion_bench.utils.rich_utils import print_config_yaml
+from fusion_bench.utils.state_dict_arithmetic import (
+    _validate_state_dict_same_keys,
+    state_dict_add,
+    state_dict_hadamard_product,
+    state_dict_mul,
+    state_dict_sub,
+)
+
+log = get_rankzero_logger(__name__)
+
+
+@auto_register_config
+class ConcreteTSVMForOpenCLIP(
+    OpenCLIPClassificationMixin,
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
+    def __init__(
+        self,
+        dataloader_kwargs: DictConfig,
+        optimizer: DictConfig,
+        lr_scheduler: DictConfig,
+        max_steps: int,
+        save_interval: int,
+        initial_logits: float,
+        temperature: float,
+        eval_mask_type: Literal["continuous", "discrete"],
+        mask_checkpoint: Optional[str],
+        merge_dtype: str,
+        clamp_weights: bool,
+        tie_weights: bool,
+        strict: bool,
+        skip_training: bool,
+        # === TSVM parameters ===
+        exclude_keys: Optional[List[str]],
+        alpha: float,
+        return_single_task_models: bool = True,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        if not return_single_task_models:
+            log.warning("return_single_task_models is forced to be True here.")
+            self.return_single_task_models = True
+
+    @torch.no_grad()
+    def setup_models(self):
+        """
+        load the pre-trained model, task vectors, and construct the mask model.
+        """
+        merge_dtype = parse_dtype(self.merge_dtype)
+        modelpool = self.modelpool
+
+        # load the pre-trained model
+        pretrained_model = modelpool.load_pretrained_model()
+        self.set_clip_processor(stage="test", processor=pretrained_model.val_preprocess)
+
+        # constrcute mask model
+        mask_model = MaskModel(
+            pretrained_model, ignore_untrained_params=True, parameter_type="logits"
+        )
+        if merge_dtype is not None:
+            mask_model.to(merge_dtype)
+        mask_model.fill_(self.initial_logits)
+
+        if self.fabric.is_global_zero:
+            print("summary of mask model:")
+            print_parameters(mask_model)
+
+        if self.fabric.is_global_zero:
+            tsvm_algo = TaskSingularVectorMerging(
+                alpha=self.alpha,
+                exclude_keys=self.exclude_keys,
+                return_single_task_models=self.return_single_task_models,
+            )
+            tsvm_algo._fabric_instance = self.fabric
+            models = tsvm_algo.run(modelpool)
+
+            finetuned_models = [models[name] for name in modelpool.model_names]
+
+            task_wise_weight = get_task_wise_weights(
+                num_models=len(modelpool.model_names),
+                init_values=self.alpha,
+            )
+
+            # create a wrapped model
+            module = TaskWiseMergedModel(
+                task_wise_weight=task_wise_weight,
+                pretrained_model=pretrained_model,
+                finetuned_models=finetuned_models,
+                clamp_weights=self.clamp_weights,
+                tie_weights=self.tie_weights,
+                strict=self.strict,
+                task_vector_dtype=merge_dtype,
+            )
+            module = module.to(dtype=merge_dtype)
+
+            print("trainable parameter summary of merged model (TaskWiseMergedModel):")
+            print_trainable_parameters(module)
+        else:
+            module = None
+
+        with torch.no_grad():
+            self.fabric.barrier()
+            module = self.fabric.broadcast(module, src=0)
+
+        return module, mask_model
+
+    def train_mask(self, module: TaskWiseMergedModel, mask_model: MaskModel):
+        """
+        Train the mask model using the provided module.
+
+        This method configures the optimizer, sets up the mask model, and performs test-time adaptation to train the mask model.
+
+        Args:
+            module (TaskWiseMergedModel): The wrapped model with task-wise weights.
+            mask_model (MaskModel): The mask model to be trained.
+        """
+        config = self.config
+        merge_dtype = parse_dtype(self.merge_dtype)
+        log.info(f"Using merge dtype: {merge_dtype}")
+
+        optimizer: "torch.optim.Optimizer" = instantiate(
+            self.optimizer,
+            params=filter(lambda p: p.requires_grad, mask_model.parameters()),
+        )
+        print(f"{optimizer=}")
+        if self.lr_scheduler is not None:
+            lr_scheduler = instantiate(
+                self.lr_scheduler,
+                optimizer=optimizer,
+            )
+            print(f"{lr_scheduler=}")
+        else:
+            lr_scheduler = None
+
+        log.info("Setup models and optimizer with Fabric.")
+        mask_model, optimizer = self.fabric.setup(mask_model, optimizer)
+
+        log.info("Move the merged module to the correct device and disable gradients.")
+        module.requires_grad_(False)
+        module.to(mask_model.device)
+
+        mask_model.train()
+        optimizer.zero_grad()
+        for step_idx in (
+            pbar := tqdm(
+                range(self.config.max_steps if not self.is_debug_mode else 5),
+                ("[DEBUG MODE] " if self.is_debug_mode else "")
+                + "Concrete TSVM Test-Time Adaptation",
+                dynamic_ncols=True,
+                disable=not self.fabric.is_global_zero,
+            )
+        ):
+            metrics = {}
+            # sample a shared mask and merge weights
+            with self.profile("sample mask"):
+                mask = mask_model.sample_mask(
+                    mask_type="continuous", temperature=config.temperature
+                )
+                metrics["train/sparsity"] = mask_sparsity(mask)
+            with self.profile("merge weights"):
+                # rescale mask
+                for name, m in mask.items():
+                    mask[name] = m / torch.mean(m)
+                module.merge_weights(task_vector_mask=mask)
+
+            # ------ inner optimization goes here ------
+            # NOTE:
+            # Because the algorithmic parameters of TSVM are assumed to be chosen on a validation test
+            # set, we do not need to perform inner optimization here. So here we skip the inner optimization step.
+            # ------------------------------------------
+
+            total_loss = None
+            for task in self.modelpool.model_names:
+                with self.profile("data loading"):
+                    batch = next(self.get_shuffled_test_loader_iter(task))
+                    # NOTE: The labels are not allowed to be used during test-time adaptation
+                    images = batch[0].to(dtype=merge_dtype)
+                with self.profile("forward pass"):
+                    logits = self.compute_logits(module, images, task)
+                    loss = entropy_loss(logits)
+                    total_loss = loss if total_loss is None else total_loss + loss
+
+            with self.profile("compute grad"):
+                self.fabric.backward(total_loss)
+
+            with self.profile("optimizer step"):
+                optimizer.step()
+                optimizer.zero_grad()
+
+            if lr_scheduler is not None:
+                lr_scheduler.step()
+
+            metrics.update({"train/loss": loss.item()})
+            self.fabric.log_dict(metrics, step=step_idx)
+            pbar.set_postfix(metrics)
+
+            if (step_idx + 1) % self.config.save_interval == 0:
+                with self.profiler.profile("save checkpoint"):
+                    save_dir = os.path.join(self.fabric.logger.log_dir, "checkpoints")
+                    if not os.path.exists(save_dir):
+                        os.makedirs(save_dir, exist_ok=True)
+                    save_path = os.path.join(save_dir, f"mask_steps_{step_idx}.pt")
+                    print(f"saving checkpoint to {save_path}")
+                    state = {"model": mask_model}
+                    self.fabric.save(save_path, state)
+
+                    # Create or update a symbolic link to the latest checkpoint
+                    if self.fabric.is_global_zero:
+                        symlink_path = os.path.join(save_dir, "latest_checkpoint.pt")
+                        if os.path.exists(symlink_path):
+                            os.remove(symlink_path)
+                        os.link(os.path.abspath(save_path), symlink_path)
+
+        self.print_profile_summary()
+
+    def run(self, modelpool: OpenCLIPVisionModelPool):
+        self.modelpool = modelpool
+        merge_dtype = parse_dtype(self.merge_dtype)
+
+        with self.profile("setup models"):
+            module, mask_model = self.setup_models()
+            self.setup_zero_shot_classification_head(freeze=True, dtype=merge_dtype)
+
+        if self.mask_checkpoint is None:
+            if not self.skip_training:
+                clear_cuda_cache()
+                self.train_mask(module, mask_model=mask_model)
+        else:
+            if self.fabric.is_global_zero:
+                print("loading mask from checkpoint", self.mask_checkpoint)
+            self.fabric.load(self.mask_checkpoint, {"model": mask_model})
+
+        with torch.no_grad():
+            clear_cuda_cache()
+            mask = mask_model.sample_mask(
+                mask_type=self.eval_mask_type, temperature=self.temperature
+            )
+            # rescale mask
+            for name, m in mask.items():
+                mask[name] = m / torch.mean(m)
+            model = module.merge_and_unload(mask)
+        return model.to(dtype=torch.float32)
fusion_bench/method/dop/dop.py
CHANGED
@@ -79,28 +79,6 @@ class ContinualDOPForCLIP(BaseAlgorithm, LightningFabricMixin):
         ), "The alpha should be in the range of [0, 1]"
         super().__init__(**kwargs)
 
-    def print_params(self, pretrained_model):
-        total_params = 0
-        linear_params = 0
-        linear_weight_params = 0
-        for module_name, module in pretrained_model.named_modules():
-            if not is_leaf_module(module):
-                continue
-            if isinstance(module, nn.Linear):
-                linear_params += sum(p.numel() for n, p in module.named_parameters())
-                linear_weight_params += sum(
-                    p.numel() for n, p in module.named_parameters() if "weight" in n
-                )
-            total_params += sum(p.numel() for p in module.parameters())
-
-        linear_ratio = linear_params / total_params * 100
-        linear_weight_ratio = linear_weight_params / total_params * 100
-        print(f"Total Parameters: {total_params}")
-        print(f"Linear Parameters: {linear_params}")
-        print(f"Linear Weight Parameters: {linear_weight_params}")
-        print(f"Linear Ratio: {linear_ratio:.2f}%")
-        print(f"Linear Weight Ratio: {linear_weight_ratio:.2f}%")
-
     def run(self, modelpool: BaseModelPool):
         if self.seed is not None:
             L.seed_everything(self.seed)