fusion-bench 0.2.19__py3-none-any.whl → 0.2.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +2 -0
- fusion_bench/method/linear/simple_average_for_llama.py +14 -3
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +192 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +365 -0
- fusion_bench/method/simple_average.py +18 -2
- fusion_bench/modelpool/clip_vision/modelpool.py +45 -12
- fusion_bench/scripts/cli.py +1 -1
- fusion_bench/utils/misc.py +48 -2
- fusion_bench/utils/modelscope.py +146 -0
- fusion_bench/utils/state_dict_arithmetic.py +10 -5
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.20.dist-info}/METADATA +9 -1
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.20.dist-info}/RECORD +44 -39
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.20.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.20.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.20.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.20.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from copy import deepcopy
|
|
3
|
-
from typing import Optional, Union
|
|
3
|
+
from typing import Literal, Optional, Union
|
|
4
4
|
|
|
5
5
|
from datasets import load_dataset
|
|
6
6
|
from lightning.fabric.utilities import rank_zero_only
|
|
@@ -11,6 +11,9 @@ from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
|
|
|
11
11
|
from typing_extensions import override
|
|
12
12
|
|
|
13
13
|
from fusion_bench.utils import instantiate, timeit_context
|
|
14
|
+
from fusion_bench.utils.modelscope import (
|
|
15
|
+
resolve_repo_path,
|
|
16
|
+
)
|
|
14
17
|
|
|
15
18
|
from ..base_pool import BaseModelPool
|
|
16
19
|
|
|
@@ -25,25 +28,32 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
25
28
|
the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.
|
|
26
29
|
"""
|
|
27
30
|
|
|
28
|
-
_config_mapping = BaseModelPool._config_mapping | {
|
|
31
|
+
_config_mapping = BaseModelPool._config_mapping | {
|
|
32
|
+
"_processor": "processor",
|
|
33
|
+
"_platform": "hf",
|
|
34
|
+
}
|
|
29
35
|
|
|
30
36
|
def __init__(
|
|
31
37
|
self,
|
|
32
38
|
models: DictConfig,
|
|
33
39
|
*,
|
|
34
40
|
processor: Optional[DictConfig] = None,
|
|
41
|
+
platform: Literal["hf", "huggingface", "modelscope"] = "hf",
|
|
35
42
|
**kwargs,
|
|
36
43
|
):
|
|
37
44
|
super().__init__(models, **kwargs)
|
|
38
|
-
|
|
39
45
|
self._processor = processor
|
|
46
|
+
self._platform = platform
|
|
40
47
|
|
|
41
48
|
def load_processor(self, *args, **kwargs) -> CLIPProcessor:
|
|
42
49
|
assert self._processor is not None, "Processor is not defined in the config"
|
|
43
50
|
if isinstance(self._processor, str):
|
|
44
51
|
if rank_zero_only.rank == 0:
|
|
45
52
|
log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
|
|
46
|
-
|
|
53
|
+
repo_path = resolve_repo_path(
|
|
54
|
+
repo_id=self._processor, repo_type="model", platform=self._platform
|
|
55
|
+
)
|
|
56
|
+
processor = CLIPProcessor.from_pretrained(repo_path, *args, **kwargs)
|
|
47
57
|
else:
|
|
48
58
|
processor = instantiate(self._processor, *args, **kwargs)
|
|
49
59
|
return processor
|
|
@@ -54,7 +64,10 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
54
64
|
if isinstance(model_config, str):
|
|
55
65
|
if rank_zero_only.rank == 0:
|
|
56
66
|
log.info(f"Loading `transformers.CLIPModel`: {model_config}")
|
|
57
|
-
|
|
67
|
+
repo_path = resolve_repo_path(
|
|
68
|
+
repo_id=model_config, repo_type="model", platform=self._platform
|
|
69
|
+
)
|
|
70
|
+
clip_model = CLIPModel.from_pretrained(repo_path, *args, **kwargs)
|
|
58
71
|
return clip_model
|
|
59
72
|
else:
|
|
60
73
|
assert isinstance(
|
|
@@ -107,14 +120,17 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
107
120
|
if isinstance(model, str):
|
|
108
121
|
if rank_zero_only.rank == 0:
|
|
109
122
|
log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
|
|
110
|
-
|
|
123
|
+
repo_path = resolve_repo_path(
|
|
124
|
+
model, repo_type="model", platform=self._platform
|
|
125
|
+
)
|
|
126
|
+
return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
|
|
111
127
|
if isinstance(model, nn.Module):
|
|
112
128
|
if rank_zero_only.rank == 0:
|
|
113
129
|
log.info(f"Returning existing model: {model}")
|
|
114
130
|
return model
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
131
|
+
else:
|
|
132
|
+
# If the model is not a string, we use the default load_model method
|
|
133
|
+
return super().load_model(model_name_or_config, *args, **kwargs)
|
|
118
134
|
|
|
119
135
|
def load_train_dataset(self, dataset_name: str, *args, **kwargs):
|
|
120
136
|
dataset_config = self._train_datasets[dataset_name]
|
|
@@ -123,7 +139,7 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
123
139
|
log.info(
|
|
124
140
|
f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
|
|
125
141
|
)
|
|
126
|
-
dataset =
|
|
142
|
+
dataset = self._load_dataset(dataset_config, split="train")
|
|
127
143
|
else:
|
|
128
144
|
dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
|
|
129
145
|
return dataset
|
|
@@ -135,7 +151,7 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
135
151
|
log.info(
|
|
136
152
|
f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
|
|
137
153
|
)
|
|
138
|
-
dataset =
|
|
154
|
+
dataset = self._load_dataset(dataset_config, split="validation")
|
|
139
155
|
else:
|
|
140
156
|
dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
|
|
141
157
|
return dataset
|
|
@@ -147,7 +163,24 @@ class CLIPVisionModelPool(BaseModelPool):
|
|
|
147
163
|
log.info(
|
|
148
164
|
f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
|
|
149
165
|
)
|
|
150
|
-
dataset =
|
|
166
|
+
dataset = self._load_dataset(dataset_config, split="test")
|
|
151
167
|
else:
|
|
152
168
|
dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
|
|
153
169
|
return dataset
|
|
170
|
+
|
|
171
|
+
def _load_dataset(self, name: str, split: str):
    """
    Resolve a dataset repository to a local path and load one split of it.

    Args:
        name (str): Dataset identifier handed to `resolve_repo_path` together
            with `repo_type="dataset"` and the pool's configured platform.
        split (str): The split of the dataset to load (e.g., "train", "validation", "test").

    Returns:
        Whatever `datasets.load_dataset` returns for the resolved path and
        the requested split.
    """
    # Resolution is delegated to `resolve_repo_path`, driven by `self._platform`.
    local_dir = resolve_repo_path(
        name,
        repo_type="dataset",
        platform=self._platform,
    )
    return load_dataset(local_dir, split=split)
|
fusion_bench/scripts/cli.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
"""
|
|
3
|
-
This is the CLI script that is executed when the user runs the `
|
|
3
|
+
This is the CLI script that is executed when the user runs the `fusion_bench` command.
|
|
4
4
|
The script is responsible for parsing the command-line arguments, loading the configuration file, and running the fusion algorithm.
|
|
5
5
|
"""
|
|
6
6
|
|
fusion_bench/utils/misc.py
CHANGED
|
@@ -1,6 +1,13 @@
|
|
|
1
|
-
from
|
|
1
|
+
from difflib import get_close_matches
|
|
2
|
+
from typing import Any, Iterable, List, Optional
|
|
2
3
|
|
|
3
|
-
__all__ = [
|
|
4
|
+
__all__ = [
|
|
5
|
+
"first",
|
|
6
|
+
"has_length",
|
|
7
|
+
"join_list",
|
|
8
|
+
"attr_equal",
|
|
9
|
+
"validate_and_suggest_corrections",
|
|
10
|
+
]
|
|
4
11
|
|
|
5
12
|
|
|
6
13
|
def first(iterable: Iterable):
|
|
@@ -41,3 +48,42 @@ def attr_equal(obj, attr: str, value):
|
|
|
41
48
|
if not hasattr(obj, attr):
|
|
42
49
|
return False
|
|
43
50
|
return getattr(obj, attr) == value
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def validate_and_suggest_corrections(
    obj: Any, values: Iterable[Any], *, max_suggestions: int = 3, cutoff: float = 0.6
) -> Any:
    """
    Return *obj* if it is contained in *values*.
    Otherwise raise a helpful ``ValueError`` that lists the closest matches.

    Args:
        obj : Any
            The value to validate.
        values : Iterable[Any]
            The set of allowed values. It is materialised into a list once,
            so a one-shot iterator is acceptable.
        max_suggestions : int, optional
            How many typo-hints to include at most (default 3). A value of
            0 or less disables the hints; the error then only lists the
            allowed values.
        cutoff : float, optional
            Similarity threshold for suggestions (0.0-1.0, default 0.6),
            forwarded to ``difflib.get_close_matches``.

    Returns:
        The original *obj* if it is valid.

    Raises:
        ValueError: With a friendly message that points out possible typos.
    """
    # Normalise to a list so it can be both searched and printed.
    value_list = list(values)

    if obj in value_list:
        return obj

    msg = f"Invalid value {obj!r}. Allowed values: {value_list}"

    # `get_close_matches` rejects n <= 0 with its own ValueError ("n must be
    # > 0"), which would replace the message built above. Skip the lookup in
    # that case so the caller still gets the friendly error.
    if max_suggestions > 0:
        candidates = [str(value) for value in value_list]
        matches = get_close_matches(
            str(obj), candidates, n=max_suggestions, cutoff=cutoff
        )
        if matches:
            msg += f". Did you mean {', '.join(repr(m) for m in matches)}?"

    raise ValueError(msg)
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Literal, Optional
|
|
3
|
+
|
|
4
|
+
from datasets import load_dataset as datasets_load_dataset
|
|
5
|
+
from fusion_bench.utils import validate_and_suggest_corrections
|
|
6
|
+
|
|
7
|
+
try:
|
|
8
|
+
from modelscope import snapshot_download as modelscope_snapshot_download
|
|
9
|
+
except ImportError:
|
|
10
|
+
|
|
11
|
+
def modelscope_snapshot_download(*args, **kwargs):
|
|
12
|
+
raise ImportError(
|
|
13
|
+
"ModelScope is not installed. Please install it using `pip install modelscope` to use ModelScope models."
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
try:
|
|
18
|
+
from huggingface_hub import snapshot_download as huggingface_snapshot_download
|
|
19
|
+
except ImportError:
|
|
20
|
+
|
|
21
|
+
def huggingface_snapshot_download(*args, **kwargs):
|
|
22
|
+
raise ImportError(
|
|
23
|
+
"Hugging Face Hub is not installed. Please install it using `pip install huggingface_hub` to use Hugging Face models."
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
__all__ = [
|
|
28
|
+
"load_dataset",
|
|
29
|
+
"resolve_repo_path",
|
|
30
|
+
]
|
|
31
|
+
|
|
32
|
+
AVAILABLE_PLATFORMS = ["hf", "huggingface", "modelscope"]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def load_dataset(
    name: str,
    split: str = "train",
    platform: Literal["hf", "huggingface", "modelscope"] = "hf",
):
    """
    Load one split of a dataset from Hugging Face or ModelScope.

    Args:
        name (str): The name of the dataset.
        split (str): The split of the dataset to load (default is "train").
        platform (Literal['hf', 'huggingface', 'modelscope']): The platform to
            load the dataset from. "hf" and "huggingface" are synonyms.

    Returns:
        Dataset: The loaded dataset.

    Raises:
        ValueError: If `platform` is not one of `AVAILABLE_PLATFORMS`.
    """
    validate_and_suggest_corrections(platform, AVAILABLE_PLATFORMS)

    if platform in ("hf", "huggingface"):
        # The `datasets` library takes the hub id directly.
        source = name
    elif platform == "modelscope":
        # ModelScope datasets are snapshotted first, then loaded from the
        # path returned by the download call.
        source = modelscope_snapshot_download(name, repo_type="dataset")
    else:
        raise ValueError(
            f"Unsupported platform: {platform}. Supported platforms are 'hf', 'huggingface', and 'modelscope'."
        )
    return datasets_load_dataset(source, split=split)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def resolve_repo_path(
    repo_id: str,
    repo_type: Optional[str] = "model",
    platform: Literal["hf", "huggingface", "modelscope"] = "hf",
    **kwargs,
):
    """
    Resolve and download a repository from various platforms to a local path.

    This function handles multiple repository sources including local paths, Hugging Face,
    and ModelScope. It automatically downloads remote repositories to local cache and
    returns the local path for further use.

    Args:
        repo_id (str): Repository identifier. Can be:
            - Local file/directory path (returned as-is if exists)
            - Hugging Face model/dataset ID (e.g., "bert-base-uncased")
            - ModelScope model/dataset ID
            - URL-prefixed ID (e.g., "hf://model-name", "modelscope://model-name").
              The prefix will override the platform argument. Only a leading
              prefix is stripped; the rest of the identifier is left untouched.
        repo_type (str, optional): Type of repository to download. Defaults to "model".
            Common values include "model" and "dataset".
        platform (Literal["hf", "huggingface", "modelscope"], optional):
            Platform to download from. Defaults to "hf". Options:
            - "hf" or "huggingface": Hugging Face Hub
            - "modelscope": ModelScope platform
        **kwargs: Additional arguments passed to the underlying download functions.

    Returns:
        str: Local path to the repository (either existing local path or downloaded cache path).

    Raises:
        FileNotFoundError: If the repository cannot be resolved. Every failure
            after the local-path check (unsupported platform, missing optional
            dependency, download error) is reported as this exception type,
            with the underlying error attached as ``__cause__``.

    Examples:
        >>> # Local path (returned as-is)
        >>> resolve_repo_path("/path/to/local/model")
        "/path/to/local/model"

        >>> # Hugging Face model
        >>> resolve_repo_path("bert-base-uncased")
        "/home/user/.cache/huggingface/hub/models--bert-base-uncased/..."

        >>> # ModelScope model with explicit platform
        >>> resolve_repo_path("damo/nlp_bert_backbone_base_std", platform="modelscope")
        "/home/user/.cache/modelscope/hub/damo/nlp_bert_backbone_base_std/..."

        >>> # URL-prefixed repository ID
        >>> resolve_repo_path("hf://microsoft/DialoGPT-medium")
        "/home/user/.cache/huggingface/hub/models--microsoft--DialoGPT-medium/..."
    """
    # A scheme prefix names the platform explicitly and wins over `platform`.
    # Slice off the leading prefix only: `str.replace` would also delete the
    # same text if it occurred later in the identifier.
    for prefix, prefixed_platform in (
        ("hf://", "hf"),
        ("huggingface://", "hf"),
        ("modelscope://", "modelscope"),
    ):
        if repo_id.startswith(prefix):
            repo_id = repo_id[len(prefix) :]
            platform = prefixed_platform
            break

    # If it's a local file or directory, return as is. This happens before the
    # platform is validated, so local paths work regardless of `platform`.
    if os.path.exists(repo_id):
        return repo_id

    try:
        validate_and_suggest_corrections(platform, AVAILABLE_PLATFORMS)
        # This will download the model to the cache and return the local path
        if platform in ["hf", "huggingface"]:
            local_path = huggingface_snapshot_download(
                repo_id=repo_id, repo_type=repo_type, **kwargs
            )
        elif platform == "modelscope":
            local_path = modelscope_snapshot_download(
                repo_id=repo_id, repo_type=repo_type, **kwargs
            )
        else:
            raise ValueError(
                f"Unsupported platform: {platform}. Supported platforms are 'hf', 'huggingface', and 'modelscope'."
            )
        return local_path
    except Exception as e:
        # Keep the single exception type callers already handle, but chain the
        # original error so the real cause stays visible in the traceback.
        raise FileNotFoundError(
            f"Could not resolve checkpoint: {repo_id}. Error: {e}"
        ) from e
|
|
@@ -4,6 +4,7 @@ from typing import Callable, Dict, List, Literal, Union, cast
|
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
from torch import Tensor
|
|
7
|
+
from tqdm.auto import tqdm
|
|
7
8
|
|
|
8
9
|
from .parameters import check_parameters_all_equal
|
|
9
10
|
from .type import BoolStateDictType, StateDictType
|
|
@@ -124,7 +125,11 @@ def state_dict_sub(
|
|
|
124
125
|
|
|
125
126
|
|
|
126
127
|
def state_dict_add(
|
|
127
|
-
a: StateDictType,
|
|
128
|
+
a: StateDictType,
|
|
129
|
+
b: StateDictType,
|
|
130
|
+
strict: bool = True,
|
|
131
|
+
device=None,
|
|
132
|
+
show_pbar: bool = False,
|
|
128
133
|
):
|
|
129
134
|
"""
|
|
130
135
|
Returns the sum of two state dicts.
|
|
@@ -140,10 +145,10 @@ def state_dict_add(
|
|
|
140
145
|
ans = {}
|
|
141
146
|
if strict:
|
|
142
147
|
check_parameters_all_equal([a, b])
|
|
143
|
-
for key in a:
|
|
148
|
+
for key in tqdm(tuple(a.keys())) if show_pbar else a:
|
|
144
149
|
ans[key] = a[key] + b[key]
|
|
145
150
|
else:
|
|
146
|
-
for key in a:
|
|
151
|
+
for key in tqdm(tuple(a.keys())) if show_pbar else a:
|
|
147
152
|
if key in b:
|
|
148
153
|
ans[key] = a[key] + b[key]
|
|
149
154
|
if device is not None:
|
|
@@ -175,7 +180,7 @@ def state_dict_mul(state_dict: StateDictType, scalar: float):
|
|
|
175
180
|
return diff
|
|
176
181
|
|
|
177
182
|
|
|
178
|
-
def state_dict_div(state_dict: StateDictType, scalar: float):
|
|
183
|
+
def state_dict_div(state_dict: StateDictType, scalar: float, show_pbar: bool = False):
|
|
179
184
|
"""
|
|
180
185
|
Returns the division of a state dict by a scalar.
|
|
181
186
|
|
|
@@ -187,7 +192,7 @@ def state_dict_div(state_dict: StateDictType, scalar: float):
|
|
|
187
192
|
Dict: The division of the state dict by the scalar.
|
|
188
193
|
"""
|
|
189
194
|
diff = OrderedDict()
|
|
190
|
-
for k in state_dict:
|
|
195
|
+
for k in tqdm(tuple(state_dict.keys())) if show_pbar else state_dict:
|
|
191
196
|
diff[k] = state_dict[k] / scalar
|
|
192
197
|
return diff
|
|
193
198
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: fusion_bench
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.20
|
|
4
4
|
Summary: A Comprehensive Benchmark of Deep Model Fusion
|
|
5
5
|
Author-email: Anke Tang <tang.anke@foxmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -72,6 +72,14 @@ FusionBench is a benchmark suite designed to evaluate the performance of various
|
|
|
72
72
|
|
|
73
73
|
Projects based on FusionBench and news from the community (descending order of date. If you have any work based on FusionBench, please feel free to let us know, we are willing to add it to the list. :partying_face:):
|
|
74
74
|
|
|
75
|
+
<details>
|
|
76
|
+
<summary>The-Hai Nguyen, Dang Huu-Tien, Takeshi Suzuki, and Le-Minh Nguyen. RegMean++: Enhancing Effectiveness and Generalization of Regression Mean for Model Merging. Aug, 2025. https://www.arxiv.org/abs/2508.03121</summary>
|
|
77
|
+
|
|
78
|
+
Regression Mean (RegMean), an approach that formulates model merging as a linear regression problem, aims to find the optimal weights for each linear layer in the merge model by minimizing the discrepancy in predictions between the merge and candidate models. RegMean provides a precise closed-form solution for the merging problem; therefore, it offers explainability and computational efficiency. However, RegMean merges each linear layer independently, overlooking how the features and information in the earlier layers propagate through the layers and influence the final prediction in the merge model. In this paper, we introduce RegMean++, a simple yet effective alternative to RegMean, that explicitly incorporates both intra- and cross-layer dependencies between merge models' layers into RegMean's objective. By accounting for these dependencies, RegMean++ better captures the behaviors of the merge model. Extensive experiments demonstrate that RegMean++ consistently outperforms RegMean across diverse settings, including in-domain (ID) and out-of-domain (OOD) generalization, sequential merging, large-scale tasks, and robustness under several types of distribution shifts. Furthermore, RegMean++ achieves competitive or state-of-the-art performance compared to various recent advanced model merging methods.
|
|
79
|
+
|
|
80
|
+
<img width="1000" alt="image" src="docs/algorithms/images/regmean_vs_regmean_plusplus.png">
|
|
81
|
+
</details>
|
|
82
|
+
|
|
75
83
|
<details>
|
|
76
84
|
<summary>Hao Mark Chen, et al. FW-Merging: Scaling Model Merging with Frank-Wolfe Optimization. Mar 2025. https://arxiv.org/abs/2503.12649</summary>
|
|
77
85
|
|
|
@@ -43,12 +43,12 @@ fusion_bench/dataset/llama/stanford_shp.py,sha256=6ueXKnFXIBBobacU1h5WxGLZrSOtBk
|
|
|
43
43
|
fusion_bench/dataset/llama/ultrachat.py,sha256=Go7WvrDAYnm184fdazHGRYLbSY6Xd7jrESyQeUJtOww,1736
|
|
44
44
|
fusion_bench/dataset/llama/wikitext.py,sha256=9ZHR-nMfXRumd3o-PIj3n7B83YlVeqpGkZ2zJs2B-9Y,2883
|
|
45
45
|
fusion_bench/dataset/llama/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
46
|
-
fusion_bench/method/__init__.py,sha256=
|
|
46
|
+
fusion_bench/method/__init__.py,sha256=o8t4R4fSTVaYuv2c0uD0_5ZolGuRk3nLawanYZghoCk,8015
|
|
47
47
|
fusion_bench/method/base_algorithm.py,sha256=UuITuGnSskcKEwUVINuPoWJUwqGm9AIgyQIOCu8BMks,1162
|
|
48
48
|
fusion_bench/method/dummy.py,sha256=hb1y6LR_geRZ5eRgGwt5zJUcHYorCeIbs5i76CvurUc,1031
|
|
49
49
|
fusion_bench/method/ensemble.py,sha256=rGxvJTeorfcBuE_e0XO-0-MAc9un7ZCC46ikKGuAcN4,3077
|
|
50
50
|
fusion_bench/method/model_recombination.py,sha256=2tviqmYSPOL0_Ktv8_gt_YzQ4tyCANHxXquUot_3Cgo,5360
|
|
51
|
-
fusion_bench/method/simple_average.py,sha256=
|
|
51
|
+
fusion_bench/method/simple_average.py,sha256=omUIZn7VnrerAHbo9lNGUlL5EcpUeGv3C-fJczcXaOk,5458
|
|
52
52
|
fusion_bench/method/ada_svd/__init__.py,sha256=4XzQbbvE9HI3NtEmEFvo8iC3ds_85vJXe7P7qJfL7kk,77
|
|
53
53
|
fusion_bench/method/ada_svd/clip_vision.py,sha256=XvXgIdlShAREMsubRgphyycGrhWqSnuVBo6S9bNYSd0,12581
|
|
54
54
|
fusion_bench/method/adamerging/__init__.py,sha256=nt0saBT_3bqghk-pINQ-XCWm9UWwSZllu4R1sDuAJAA,376
|
|
@@ -121,7 +121,7 @@ fusion_bench/method/linear/__init__.py,sha256=ChfkoOEAb-rUKwpowFPel-a1hRfS8gCrbn
|
|
|
121
121
|
fusion_bench/method/linear/expo.py,sha256=LCHTWlsPm1Mjhrq0mfpWLVC7skkI9ZksGduy3TxULoU,3939
|
|
122
122
|
fusion_bench/method/linear/linear_interpolation.py,sha256=IONw9BPiRJouY8bE9Abfyz7qVI_1B1n8KGZa0f7Pza8,2157
|
|
123
123
|
fusion_bench/method/linear/llama_expo.py,sha256=ccECjhAqcFmzOIDyZ7e_aPzTM2Kj8u2D8TJytyz18YM,8476
|
|
124
|
-
fusion_bench/method/linear/simple_average_for_llama.py,sha256=
|
|
124
|
+
fusion_bench/method/linear/simple_average_for_llama.py,sha256=ZbkTEgTNYBq8p3q_cxA6gky0mdZs7FaGfHlZKfFEbiQ,2952
|
|
125
125
|
fusion_bench/method/linear/task_arithmetic_for_llama.py,sha256=4SZpiTD7OzhWUXtcdK3PYdXbBGyDqiZd7oZOQ0lraN0,1963
|
|
126
126
|
fusion_bench/method/lm_finetune/__init__.py,sha256=IFGAqXujX3Fabzl_tC6zZyOyPFJfVziL0qFtj5MVxj0,149
|
|
127
127
|
fusion_bench/method/lm_finetune/bradley_terry_rm.py,sha256=ys_td1IeL3bzPTE0Cixlj2JooCaB7qseRwSDwroAk5A,18777
|
|
@@ -183,6 +183,9 @@ fusion_bench/method/regmean/__init__.py,sha256=VVqAkdHkb005Sc2XmeiedQYzb3q5aQNI8
|
|
|
183
183
|
fusion_bench/method/regmean/clip_regmean.py,sha256=xhT7dYSCg9sPLL5ZUCCtcA-Ypw4PBHsOivrnz-3fDso,4931
|
|
184
184
|
fusion_bench/method/regmean/gpt2_regmean.py,sha256=p2D3E8YAZsltsI6GM474UWNqPZfBqihLZ93ZLUpOJ_c,5565
|
|
185
185
|
fusion_bench/method/regmean/regmean.py,sha256=IqkweSS4WXGbMquWaHWHijiZ6rMIaHfFSje8C8bbb6g,16587
|
|
186
|
+
fusion_bench/method/regmean_plusplus/__init__.py,sha256=TD_KeTGXkbsh7tIsFDAnTDmYSvkc1ii2HfYQeNCu58A,141
|
|
187
|
+
fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py,sha256=s0HMUHTg6MV8B_WnF6D8LHRhrbr9aHChxMhnDj8qweE,7385
|
|
188
|
+
fusion_bench/method/regmean_plusplus/regmean_plusplus.py,sha256=and0YbngsWdq3zLBPOqTkhsCxR4WI9FjuoDQbLQfR4M,16069
|
|
186
189
|
fusion_bench/method/slerp/__init__.py,sha256=Wgl9gg01Xou6jyZeBRD98kRnB_dAADDaPqRTHoONx9o,59
|
|
187
190
|
fusion_bench/method/slerp/slerp.py,sha256=2_n10REnRoV5DuwCC0bDX8RM3MLL4Q_5rZiU0hICw2w,3406
|
|
188
191
|
fusion_bench/method/slerp/slerp_utils.py,sha256=vksRo6n7FqY7By9aqbwTL4XV3BjcU_GrUl_r85Kpfjc,3504
|
|
@@ -255,7 +258,7 @@ fusion_bench/modelpool/nyuv2_modelpool.py,sha256=btuXmYxwfjI6MnGakhoOf53Iyb9fxYH
|
|
|
255
258
|
fusion_bench/modelpool/causal_lm/__init__.py,sha256=F432-aDIgAbUITj4GNZS9dgUKKhaDMCbTeHB-9MecaQ,99
|
|
256
259
|
fusion_bench/modelpool/causal_lm/causal_lm.py,sha256=7-mUWVGVsXyljH_06CmIyReClKx_xVjy5zeXTJcLQIk,8085
|
|
257
260
|
fusion_bench/modelpool/clip_vision/__init__.py,sha256=3b9gN2bWUsoA1EmpitnIMnIlX7nklxbkn4WJ0QJtS2c,43
|
|
258
|
-
fusion_bench/modelpool/clip_vision/modelpool.py,sha256=
|
|
261
|
+
fusion_bench/modelpool/clip_vision/modelpool.py,sha256=eemYxDPCovbdKo1yr96UJqieWxr7l_n0rQHsJT4JMno,7117
|
|
259
262
|
fusion_bench/modelpool/openclip_vision/__init__.py,sha256=QDmAitKqUwRygN9QncdS_kGWZdfTKL4uUifC8xh9c10,47
|
|
260
263
|
fusion_bench/modelpool/openclip_vision/modelpool.py,sha256=2MieB4PMvg85DaiYu49m3BzuBjib1xozJHTpYyHhRTs,11102
|
|
261
264
|
fusion_bench/modelpool/seq2seq_lm/__init__.py,sha256=FnfSMHcwNHDQEMdB2HdK4WphQ6MufsRLUkczuALjM4Q,57
|
|
@@ -343,7 +346,7 @@ fusion_bench/programs/__init__.py,sha256=oGoRp2TMI6ELxyfkeTg2h27hZJEDz9x31Asmvwv
|
|
|
343
346
|
fusion_bench/programs/base_program.py,sha256=0dX_KcMWASo53pr-ldzfUBWIjEXy6oeDWZBrfc7FIk8,195
|
|
344
347
|
fusion_bench/programs/fabric_fusion_program.py,sha256=978t9Fw9kvw-Il7rJLR2jNI1OfSxkhq1c5-5D4BgnYU,13813
|
|
345
348
|
fusion_bench/scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
346
|
-
fusion_bench/scripts/cli.py,sha256=
|
|
349
|
+
fusion_bench/scripts/cli.py,sha256=gjNghi2VSW2gg41bn-WjzSkmTE3pDLoFk9c61zLJNso,1154
|
|
347
350
|
fusion_bench/scripts/imgui.py,sha256=r9Glbfbwu3JCsX9TKQFwcHarvwA_G7ff0jWBUPW1S1U,7613
|
|
348
351
|
fusion_bench/scripts/nyuv2_mtl_train.py,sha256=W1C45R9NdF4O-UjCx1bUxRTdFE0-FlRpwJHZ5gY18rI,3602
|
|
349
352
|
fusion_bench/scripts/webui.py,sha256=ryA-2leSnHcYA88tTAYzJGDhiljbi0vl1Fibejzndlw,14398
|
|
@@ -419,14 +422,15 @@ fusion_bench/utils/instantiate_utils.py,sha256=57D8YP25OO-ArltOSsHDKtnNcA44m1yAq
|
|
|
419
422
|
fusion_bench/utils/json.py,sha256=sVCqbm9mmyHybiui-O57KFt_ULrjLtN2wipSo6VDvqE,2533
|
|
420
423
|
fusion_bench/utils/lazy_imports.py,sha256=v5l9cpHXPMaz1IVBmB5oOqefYr9vA3XvP340xT7Wy18,2796
|
|
421
424
|
fusion_bench/utils/lazy_state_dict.py,sha256=Hu8PkhbJcUikXJxWUJ7vabu2uDbnUUF6UsRS0k8i71U,16841
|
|
422
|
-
fusion_bench/utils/misc.py,sha256=
|
|
425
|
+
fusion_bench/utils/misc.py,sha256=93q0m-HYWkPK91Co5lll_J0Dxs6YahW2lD_X8fUAyTk,2420
|
|
426
|
+
fusion_bench/utils/modelscope.py,sha256=EZfvP6ExpagQXX2s0mbyE_yfmMHntQwvBN9taznnRmE,5705
|
|
423
427
|
fusion_bench/utils/packages.py,sha256=L64paDi1SmeT3gRvRV6LaqB8AeGdzIYWIRI31qSQbSk,2110
|
|
424
428
|
fusion_bench/utils/parameters.py,sha256=2vs8vo2o-nRA9NOMOYFye-X8-aHQZoYe54tM6n0r0RE,11757
|
|
425
429
|
fusion_bench/utils/path.py,sha256=hRA1CPHNnTYBUmzbftH77sHvn4aTuybEK5Tth1skP-k,531
|
|
426
430
|
fusion_bench/utils/pylogger.py,sha256=amlRsdqHpOjxmBl6f9TA8y0LaWelEWgQNcGgEGsVOIc,3333
|
|
427
431
|
fusion_bench/utils/rich_utils.py,sha256=B8DhAYuVp23pG6ZnnYrUhcL-ikHZoQeTNqlM7u4pwwU,5786
|
|
428
432
|
fusion_bench/utils/set.py,sha256=_43ZvGKJ_BK9sUslsSNhi7xEfuAQuyj3vViImnGpnCY,134
|
|
429
|
-
fusion_bench/utils/state_dict_arithmetic.py,sha256=
|
|
433
|
+
fusion_bench/utils/state_dict_arithmetic.py,sha256=7iWb9cy-eYJgiltGJXnVCrLoNHfIMXZEL_AOg6OY0dw,11613
|
|
430
434
|
fusion_bench/utils/tensorboard.py,sha256=9fkgNYR9LM38nPNkudcxL9TjLUseW-280M0k2nLff7o,1669
|
|
431
435
|
fusion_bench/utils/timer.py,sha256=RC2hP8JqaibdL0FnRyUCBRf4m7CXyfn5tE16zBWZ7hg,1338
|
|
432
436
|
fusion_bench/utils/type.py,sha256=2iu8PQzSzI2KopYwg4Pay7qpq7s_LKkl6Rhj-tjG3u0,630
|
|
@@ -437,7 +441,7 @@ fusion_bench/utils/plot/token_notebook.py,sha256=bsntXf46Zz_RavTxNiB9c3-KvHw7LFw
|
|
|
437
441
|
fusion_bench/utils/strenum/__init__.py,sha256=id9ORi1uXrDxhbmVxitJ1KDwLS4H3AAwFpaK5h1cQzw,8531
|
|
438
442
|
fusion_bench/utils/strenum/_name_mangler.py,sha256=o11M5-bURW2RBvRTYXFQIPNeqLzburdoWLIqk8X3ydw,3397
|
|
439
443
|
fusion_bench/utils/strenum/_version.py,sha256=6JQRo9LcvODbCOeVFYQb9HNJ_J9XiG_Zbn8ws2A3BV8,18466
|
|
440
|
-
fusion_bench-0.2.
|
|
444
|
+
fusion_bench-0.2.20.dist-info/licenses/LICENSE,sha256=nhnOJlw4CPuPVE0qvkGmxfFgHmKi-6nzXvTu8t0NUdg,1066
|
|
441
445
|
fusion_bench_config/README.md,sha256=Lc8YSBJ5oxf9KV5kKDivJ9LRyGuraGQPmBbgbdVA-j4,703
|
|
442
446
|
fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml,sha256=7IxLQoLRz-sRWyV8Vqc5kQcmYE_9YQz2_77pmvAkum8,1207
|
|
443
447
|
fusion_bench_config/fabric_model_fusion.yaml,sha256=YwJx_aUXm4ca4_mVItKVUOesMvmBBRGudQIOqgc1EP8,974
|
|
@@ -622,9 +626,10 @@ fusion_bench_config/method/randes/superposed_model_soup.yaml,sha256=7M9qV_wCgrE3
|
|
|
622
626
|
fusion_bench_config/method/randes/superposed_task_arithmetic.yaml,sha256=Pw0pZtwoMIPiqHfFNbN8wqNDyYb4L5p6fIOaaDSzJQg,498
|
|
623
627
|
fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml,sha256=xH8IkGnjvKLEWsms64toWhOrKIJG9dYfqQGOsVT4GDc,539
|
|
624
628
|
fusion_bench_config/method/rankone_moe/rankone_moe.yaml,sha256=rYas_GFFHvn3AgKNrI0Zp4ElL9e3SppGPrFAMa_u9r8,863
|
|
625
|
-
fusion_bench_config/method/regmean/clip_regmean.yaml,sha256=
|
|
629
|
+
fusion_bench_config/method/regmean/clip_regmean.yaml,sha256=QfkCHCLK9wbyB1Tq1S7YT3351MbWzOjUQiALE-EJBgw,426
|
|
626
630
|
fusion_bench_config/method/regmean/gpt2_regmean.yaml,sha256=n94aTboDdwSA7Tki8l_o8tYQkhXxPV8lRf-dRNPIsOs,422
|
|
627
631
|
fusion_bench_config/method/regmean/regmean.yaml,sha256=ZgVVLx-lHwVgjtjTl4VZUlthh8yyua87QvoJfmNHud4,101
|
|
632
|
+
fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml,sha256=A034ryEwvosqyQzA3KWs7kdp-3CUnoJtCujVywV-uzA,434
|
|
628
633
|
fusion_bench_config/method/slerp/slerp.yaml,sha256=xldDUULtfCdwzAkQUb0C8-TmbW7FqcAlIOsPX8p4n6w,116
|
|
629
634
|
fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml,sha256=ZMn_ImRjjc2uozf7ocQIzbgvFDpBV7S-34KptbBXVGo,200
|
|
630
635
|
fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml,sha256=VFMrkbO69d0wCjTQCuKysYGVe6hEwNu792g1QkhU5Mk,383
|
|
@@ -760,21 +765,21 @@ fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora
|
|
|
760
765
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_model_only.yaml,sha256=4_fQ7O6vUzIxLe-3mfY6qapx1rg5icQe_ODCbMspVRU,236
|
|
761
766
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14.yaml,sha256=jBFj9O84KGWOwAI8a_In3Cq_C2caNE7JPYxhKaVDjsE,508
|
|
762
767
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL14_model_only.yaml,sha256=fvebppOvcZkVQfsCRllZzEU3ifbes4y7PHqS_gy0SYY,384
|
|
763
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml,sha256=
|
|
764
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml,sha256=
|
|
768
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml,sha256=oPsGZjPxfb9wqmKzZv9MFeZNCtTh1AOY-01GK5fxMA4,2454
|
|
769
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml,sha256=pYoIPL2ebZGv7_dUQZQCDFVlU68K_VqUHQU0asYccmA,1384
|
|
765
770
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml,sha256=-jdB4ctj_NJcgbdYgog2-cighUSufAfpUvMOrGDblog,547
|
|
766
771
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml,sha256=si8oNdCzCE_UDsHevyPGXxcJ0IZfRWHrrykHryazIvY,435
|
|
767
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml,sha256=
|
|
768
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml,sha256=
|
|
769
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml,sha256=
|
|
770
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml,sha256=
|
|
771
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml,sha256=
|
|
772
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml,sha256=
|
|
773
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml,sha256=
|
|
774
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml,sha256=
|
|
775
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml,sha256=
|
|
776
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml,sha256=
|
|
777
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml,sha256=
|
|
772
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml,sha256=FUJYnr8pRVYEKoPLGrgGf6FH7ctbpCntG1YSYQV130Q,1013
|
|
773
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml,sha256=5nq8kWxu6rc5LJ83yVtzbTqaaLol_lymkbePVmGFo94,667
|
|
774
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml,sha256=vz1rmOT-bt56eDmWLT21NFq-Si_qE8EHcbWWi67VbGE,568
|
|
775
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml,sha256=rwbWzJnWmVNiwxfyNFDVc-5neIZToQm9N1VZLBt55Kk,1290
|
|
776
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml,sha256=vSiHAIRJI-p7V8u2BS45f20AfIPccjFZ2oImGUFvytg,1742
|
|
777
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml,sha256=VfrDzBIDRZcBXl1PwRiB04HcvwEARAIweHvlGGNhNmk,1902
|
|
778
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml,sha256=xh2lwb8sKMqdd6VZNj1lA2r063XXtQG0vUCXv76slNQ,1048
|
|
779
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml,sha256=qHUHlJCfYm1oxgb4gzaS0BBQ5akmdAO8g1kvO2FPxCw,2066
|
|
780
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml,sha256=383zCvjWgs5VmwBKWXcHfqqO8Fk75tD7IyhcvJ590ZU,2270
|
|
781
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml,sha256=YRv9FFqWnY4yOEtEAL91TRf2xgPB8P7Y0CAOgtX31Kk,2454
|
|
782
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml,sha256=gEHHb4LVHUteJ-MKlupbZXktff6_wvGVL0CLv4cfrk0,1385
|
|
778
783
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml,sha256=V93v7cjxF0ZPJj0wX76Q-hSNvolUaTtoeWuAImSU53g,524
|
|
779
784
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp1.yaml,sha256=2WtCV1cJEEK3R-t4Tf-YB1AIZl-d0FkE6C0CsUBm9fw,625
|
|
780
785
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_generalization_exp2.yaml,sha256=BmQ0JP8Oa5_V5pJ55nJS1xR-OIPmvySSqQ36l2jAB1w,625
|
|
@@ -782,20 +787,20 @@ fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individu
|
|
|
782
787
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_mtl.yaml,sha256=pQr8lF-hIFrPXPcZYlbSxx8PF8EClZ6nm2L4yqEmHTk,176
|
|
783
788
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_clean.yaml,sha256=7oQtoqXs37fctajb6E7UOB0GT515eEGzFNm93dWOKKk,509
|
|
784
789
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml,sha256=txMh1k0O3Spusqewp7zV0N0L9e2fg87lviDEnNJSHGQ,900
|
|
785
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml,sha256=
|
|
786
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml,sha256=
|
|
787
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml,sha256=
|
|
788
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml,sha256=
|
|
789
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml,sha256=
|
|
790
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml,sha256=
|
|
791
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml,sha256=
|
|
792
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml,sha256
|
|
793
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml,sha256=
|
|
790
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml,sha256=nm22fkXYUGAgZ39N2AzSE4CUVaODHgi8HTIxBWuTz3M,240
|
|
791
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml,sha256=i78xIL-vP28dYZaXntLsm7e9IdI2yAeUwZZze5fd9Do,288
|
|
792
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml,sha256=gDEzNfwsMtIu2xH8WSIUblx4ZyL1FsLXoSEpXPHiMaI,482
|
|
793
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml,sha256=Ej1NOsPJzLVZmBI9jK3QMhZg198IqThUZwt8_6GizUM,442
|
|
794
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml,sha256=V1KSVfjY0hqbJFnnOkoe7SD3qKBWmYyx2fqrLTX05bE,548
|
|
795
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml,sha256=jOTbHpyqBVepW7KhorStMjH9cVzHZvq6K7osgYnsVIU,443
|
|
796
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml,sha256=dFVylgWlqWWW_Hh0N7b1m5A8WYK-r-tO7TxEviR5dCY,382
|
|
797
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml,sha256=ZuH1oweRcl08OeXNIxsb68PxYwdrUfliaHYa4s64ljo,939
|
|
798
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml,sha256=LmILwWINU8cRVvxx3IKZPFcAzPyMnr8OlCJ0TDs1WMg,573
|
|
794
799
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14.yaml,sha256=TyF7CKXpBOiDoLtDVvZuBzPI9gEJo_c99j3Gwkl3FWw,510
|
|
795
800
|
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL14_model_only.yaml,sha256=i5XGxa2FoW2464R4k5AG-7r5qmzjHXkCSm1Om2cshik,386
|
|
796
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml,sha256=
|
|
797
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml,sha256=
|
|
798
|
-
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml,sha256=
|
|
801
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml,sha256=FuPWQbC9xEV5wZjuo835gOMNgbzmpK9RbjFjA_HOzqo,2476
|
|
802
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml,sha256=9PCkbrNnQSKTsm4eoUvVgjGd3IY7wHBC4LWj4kOdY4Y,1406
|
|
803
|
+
fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml,sha256=bqnyzgwIvDtV3Fb-uLf9mdFv0NW1C392lxGsGUPLsKE,400
|
|
799
804
|
fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml,sha256=HZXjqbZKpSZCHb-G8qjj03PcvXg_8mrAuewDHZp0oEw,263
|
|
800
805
|
fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml,sha256=8gr8ZtgegSHV0GHtJBiEgdYbRe8UHhO4_y8dayxZChk,506
|
|
801
806
|
fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml,sha256=oDsZkuAoh1mWUC7jZNzw8794zgX2bV5Z0esXpvbTs-c,643
|
|
@@ -877,8 +882,8 @@ fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml,sha256=3q-KMuFaM
|
|
|
877
882
|
fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml,sha256=GjpiiRownrBCpl-TNwWRW2PYePbF-Cl99jlLNPrK5T4,1017
|
|
878
883
|
fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml,sha256=WwiYMQKehtJixDPnu5o3vcWe4yJksXTWRqOzm3uVWXQ,1017
|
|
879
884
|
fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml,sha256=xGRt0J9joXTzWUew6DvoYprAWlPXhaVFw5AX4im5VQw,1017
|
|
880
|
-
fusion_bench-0.2.
|
|
881
|
-
fusion_bench-0.2.
|
|
882
|
-
fusion_bench-0.2.
|
|
883
|
-
fusion_bench-0.2.
|
|
884
|
-
fusion_bench-0.2.
|
|
885
|
+
fusion_bench-0.2.20.dist-info/METADATA,sha256=fa8CMvPeD8fg_R8YuvjM_AggXMAVsqg63fZy3I4TMmc,23634
|
|
886
|
+
fusion_bench-0.2.20.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
887
|
+
fusion_bench-0.2.20.dist-info/entry_points.txt,sha256=iUQ8MCJvda7HP4vYh2n1Teoapb4G9PBVYZkAfcc5SHU,116
|
|
888
|
+
fusion_bench-0.2.20.dist-info/top_level.txt,sha256=BuO4TL6iHL_2yPBUX9-LlIrHRczA_BNMIFwweK0PQEI,13
|
|
889
|
+
fusion_bench-0.2.20.dist-info/RECORD,,
|
|
@@ -5,7 +5,7 @@ exclude_param_names_regex: []
|
|
|
5
5
|
num_regmean_examples: 256
|
|
6
6
|
weight_transpose: true
|
|
7
7
|
# float, reduce non-diagonal elements in regmean weights by multiplying this scalar
|
|
8
|
-
reduce_non_diagonal_ratio: 0.
|
|
8
|
+
reduce_non_diagonal_ratio: 0.95
|
|
9
9
|
dataloader_kwargs:
|
|
10
10
|
batch_size: 32
|
|
11
11
|
num_workers: 0
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
_target_: fusion_bench.method.RegMeanAlgorithmForCLIPPlusPlus
|
|
2
|
+
# list, regular expression of names of parameters that need to be excluded
|
|
3
|
+
exclude_param_names_regex: []
|
|
4
|
+
# numbers of examples to compute regmean weights
|
|
5
|
+
num_regmean_examples: 256
|
|
6
|
+
weight_transpose: true
|
|
7
|
+
# float, reduce non-diagonal elements in regmean weights by multiplying this scalar
|
|
8
|
+
reduce_non_diagonal_ratio: 0.95
|
|
9
|
+
dataloader_kwargs:
|
|
10
|
+
batch_size: 32
|
|
11
|
+
num_workers: 0
|