fusion-bench 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/constants/__init__.py +5 -1
- fusion_bench/constants/runtime.py +111 -7
- fusion_bench/dataset/gsm8k.py +6 -2
- fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
- fusion_bench/method/__init__.py +1 -1
- fusion_bench/method/classification/image_classification_finetune.py +13 -2
- fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
- fusion_bench/method/task_arithmetic/task_arithmetic.py +4 -1
- fusion_bench/metrics/nyuv2/__init__.py +31 -0
- fusion_bench/metrics/nyuv2/depth.py +30 -0
- fusion_bench/metrics/nyuv2/loss.py +40 -0
- fusion_bench/metrics/nyuv2/noise.py +24 -0
- fusion_bench/metrics/nyuv2/normal.py +34 -1
- fusion_bench/metrics/nyuv2/segmentation.py +35 -1
- fusion_bench/mixins/clip_classification.py +30 -2
- fusion_bench/mixins/lightning_fabric.py +46 -5
- fusion_bench/mixins/rich_live.py +76 -0
- fusion_bench/modelpool/__init__.py +24 -2
- fusion_bench/modelpool/base_pool.py +94 -6
- fusion_bench/modelpool/convnext_for_image_classification.py +198 -0
- fusion_bench/modelpool/dinov2_for_image_classification.py +197 -0
- fusion_bench/modelpool/resnet_for_image_classification.py +4 -1
- fusion_bench/models/model_card_templates/default.md +1 -1
- fusion_bench/scripts/webui.py +250 -17
- fusion_bench/utils/__init__.py +14 -0
- fusion_bench/utils/data.py +100 -9
- fusion_bench/utils/fabric.py +185 -4
- fusion_bench/utils/json.py +55 -8
- fusion_bench/utils/validation.py +197 -0
- {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/METADATA +66 -7
- {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/RECORD +44 -40
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
- fusion_bench_config/llama_full_finetune.yaml +4 -16
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml +10 -0
- fusion_bench_config/modelpool/Dinov2ForImageClassification/dinov2-base-imagenet1k-1-layer.yaml +10 -0
- fusion_bench_config/nyuv2_config.yaml +4 -13
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench/utils/auto.py +0 -31
- {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/top_level.txt +0 -0
fusion_bench/utils/fabric.py
CHANGED
|
@@ -1,24 +1,205 @@
|
|
|
1
1
|
import time
|
|
2
|
-
from
|
|
2
|
+
from contextlib import contextmanager
|
|
3
|
+
from functools import wraps
|
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
|
|
3
5
|
|
|
4
6
|
import lightning as L
|
|
7
|
+
import torch
|
|
5
8
|
|
|
6
9
|
from fusion_bench.utils.pylogger import get_rankzero_logger
|
|
7
10
|
|
|
8
11
|
log = get_rankzero_logger(__name__)
|
|
9
12
|
|
|
13
|
+
T = TypeVar("T")
|
|
10
14
|
|
|
11
|
-
|
|
15
|
+
|
|
16
|
+
def seed_everything_by_time(fabric: Optional[L.Fabric] = None) -> int:
    """
    Seed every process with a value derived from the current wall-clock time.

    The process that is global zero (or the only process, when ``fabric`` is
    None) takes ``int(time.time())`` as the seed. When a fabric instance is
    given, all processes meet at a barrier and the value held by rank 0 is
    broadcast, so every process hands the same number to ``L.seed_everything``.

    Args:
        fabric: Optional Lightning Fabric instance used for the barrier and the
            broadcast. If None, the seed is generated and applied locally
            without any synchronization.

    Returns:
        The seed that was passed to ``L.seed_everything``.

    Example:
        ```python
        import lightning as L
        from fusion_bench.utils.fabric import seed_everything_by_time

        # Distributed: every process receives rank 0's seed
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()
        seed = seed_everything_by_time(fabric)

        # Single process: no fabric needed
        seed = seed_everything_by_time()
        ```

    Note:
        - The seed comes from ``time.time()`` truncated to whole seconds, so it
          changes from run to run.
        - With a fabric instance, only the global zero process generates a
          value; the others start from None and receive it via broadcast.
    """
    # Only the global zero process (or the lone local process) picks a value.
    is_seed_source = fabric is None or fabric.is_global_zero
    seed = int(time.time()) if is_seed_source else None
    if is_seed_source:
        log.info(f"Generated time-based seed: {seed}")

    # Share rank 0's value with every other process.
    if fabric is not None:
        log.debug(f"Broadcasting seed `{seed}` to all processes")
        fabric.barrier()
        seed = fabric.broadcast(seed, src=0)

    L.seed_everything(seed)
    return seed
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def is_distributed(fabric: Optional[L.Fabric] = None) -> bool:
    """
    Tell whether more than one process takes part in the run.

    Args:
        fabric: Optional Lightning Fabric instance. Without one the answer is
            always False.

    Returns:
        True when ``fabric`` is given and its ``world_size`` is greater than 1,
        False otherwise.

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()
        if is_distributed(fabric):
            print("Running in distributed mode")
        ```
    """
    if fabric is None:
        return False
    return fabric.world_size > 1
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def get_world_info(fabric: Optional[L.Fabric] = None) -> Dict[str, Any]:
    """
    Collect the process-topology facts of the current run in one dictionary.

    Args:
        fabric: Optional Lightning Fabric instance. When None, the values of a
            single-process run are returned.

    Returns:
        Dictionary with the keys:
            - world_size: Total number of processes
            - global_rank: Global rank of the current process
            - local_rank: Local rank on the current node
            - is_global_zero: Whether this is the main process
            - is_distributed: Whether ``world_size`` is greater than 1

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()
        info = get_world_info(fabric)
        print(f"Process {info['global_rank']}/{info['world_size']}")
        ```
    """
    if fabric is not None:
        return dict(
            world_size=fabric.world_size,
            global_rank=fabric.global_rank,
            local_rank=fabric.local_rank,
            is_global_zero=fabric.is_global_zero,
            is_distributed=fabric.world_size > 1,
        )

    # No fabric: describe a lone, non-distributed process.
    return dict(
        world_size=1,
        global_rank=0,
        local_rank=0,
        is_global_zero=True,
        is_distributed=False,
    )
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def wait_for_everyone(
    fabric: Optional[L.Fabric] = None, message: Optional[str] = None
) -> None:
    """
    Block until all processes reach this point, optionally logging first.

    Thin wrapper around ``fabric.barrier()``. The message, if any, is logged
    only by the global zero process and before the barrier is entered.

    Args:
        fabric: Optional Lightning Fabric instance. If None, the call is a
            no-op (nothing is logged either).
        message: Optional text to log before synchronizing.

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()

        # Do some work...
        wait_for_everyone(fabric, "Waiting after model loading")
        # All processes synchronized
        ```
    """
    if fabric is None:
        return

    should_log = message and fabric.is_global_zero
    if should_log:
        log.info(message)
    fabric.barrier()
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
@contextmanager
def rank_zero_only_context(fabric: Optional[L.Fabric] = None):
    """
    Context manager that reports whether this process is global rank 0.

    A context manager cannot skip the body of its ``with`` block, so the body
    runs on every process. What this manager provides is the yielded flag:
    bind it with ``as`` and guard the rank-0-only work with it.

    Args:
        fabric: Optional Lightning Fabric instance.

    Yields:
        bool: True when ``fabric`` is None or ``fabric.is_global_zero`` is
        true, False otherwise.

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()

        with rank_zero_only_context(fabric) as is_rank_zero:
            if is_rank_zero:
                print("This prints only on rank 0")
                save_checkpoint(model, "checkpoint.pt")
        ```
    """
    # No setup or teardown is needed; the manager only hands out the flag.
    yield fabric is None or fabric.is_global_zero
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def print_on_rank_zero(*args, fabric: Optional[L.Fabric] = None, **kwargs) -> None:
    """
    Forward the arguments to ``print()`` on the global zero process only.

    Args:
        *args: Positional arguments handed to ``print()``.
        fabric: Optional Lightning Fabric instance. When None, the message is
            always printed.
        **kwargs: Keyword arguments handed to ``print()``.

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()

        print_on_rank_zero("Starting training", fabric=fabric)
        # Prints only on rank 0
        ```
    """
    # Stay silent on every process that is known not to be global zero.
    if fabric is not None and not fabric.is_global_zero:
        return
    print(*args, **kwargs)
|
fusion_bench/utils/json.py
CHANGED
|
@@ -1,31 +1,78 @@
|
|
|
1
1
|
import json
|
|
2
2
|
from pathlib import Path
|
|
3
|
-
from typing import Any, Union
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Union
|
|
4
4
|
|
|
5
|
+
from fusion_bench.utils.validation import validate_file_exists
|
|
5
6
|
|
|
6
|
-
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from pyarrow.fs import FileSystem
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def save_to_json(obj, path: Union[str, Path], filesystem: "FileSystem" = None):
    """
    Save an object to a JSON file.

    Args:
        obj (Any): the object to save; it must be serializable by ``json``.
        path (Union[str, Path]): the path to save the object
        filesystem (FileSystem, optional): PyArrow FileSystem to use for writing.
            If None, uses local filesystem via standard Python open().
            Can also be an s3fs.S3FileSystem or fsspec filesystem.

    Note:
        The local and PyArrow branches both write UTF-8. For a filesystem
        object that exposes ``open`` (fsspec style), the text encoding is left
        to that object's ``open`` implementation.
    """
    if filesystem is None:
        # Explicit encoding keeps the output independent of the locale and
        # consistent with the PyArrow branch below.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(obj, f)
        return

    json_str = json.dumps(obj)
    path_str = str(path)
    if hasattr(filesystem, "open"):
        # fsspec-style filesystem (e.g. s3fs): write text directly.
        with filesystem.open(path_str, "w") as f:
            f.write(json_str)
    else:
        # PyArrow filesystem: output streams take bytes.
        with filesystem.open_output_stream(path_str) as f:
            f.write(json_str.encode("utf-8"))
|
|
16
39
|
|
|
17
40
|
|
|
18
|
-
def load_from_json(
|
|
41
|
+
def load_from_json(
    path: Union[str, Path], filesystem: "FileSystem" = None
) -> Union[dict, list]:
    """Load an object from a JSON file.

    Args:
        path (Union[str, Path]): the path to load the object
        filesystem (FileSystem, optional): PyArrow FileSystem to use for reading.
            If None, uses local filesystem via standard Python open().
            Can also be an s3fs.S3FileSystem or fsspec filesystem.

    Returns:
        Union[dict, list]: the loaded object

    Raises:
        ValidationError: If the file doesn't exist (when using local filesystem)

    Note:
        The local and PyArrow branches both decode the file as UTF-8. For a
        filesystem object that exposes ``open`` (fsspec style), decoding is
        left to that object's ``open`` implementation.
    """
    if filesystem is None:
        validate_file_exists(path)
        # Decode as UTF-8 explicitly so non-ASCII JSON loads the same way on
        # every platform, matching the PyArrow branch below.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    path_str = str(path)
    if hasattr(filesystem, "open"):
        # fsspec-style filesystem (e.g. s3fs): read text directly.
        with filesystem.open(path_str, "r") as f:
            return json.load(f)

    # PyArrow filesystem: input streams return bytes.
    with filesystem.open_input_stream(path_str) as f:
        json_data = f.read().decode("utf-8")
    return json.loads(json_data)
|
|
29
76
|
|
|
30
77
|
|
|
31
78
|
def _is_list_of_dict(obj) -> bool:
|
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Validation utilities for FusionBench.
|
|
3
|
+
|
|
4
|
+
This module provides robust input validation functions to ensure data integrity
|
|
5
|
+
and provide clear error messages throughout the FusionBench framework.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any, Optional, Union
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"ValidationError",
|
|
14
|
+
"validate_path_exists",
|
|
15
|
+
"validate_file_exists",
|
|
16
|
+
"validate_directory_exists",
|
|
17
|
+
"validate_model_name",
|
|
18
|
+
]
|
|
19
|
+
|
|
20
|
+
log = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class ValidationError(ValueError):
    """Validation failure that remembers which field and value were rejected."""

    def __init__(self, message: str, field: Optional[str] = None, value: Any = None):
        # Keep the raw context available to callers that catch the error.
        self.field = field
        self.value = value

        # Final text: optional field prefix, the message, optional value suffix.
        prefix = f"Validation error for '{field}': " if field else ""
        suffix = f" (got: {value!r})" if value is not None else ""
        super().__init__(f"{prefix}{message}{suffix}")
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def validate_path_exists(
    path: Union[str, Path],
    name: str = "path",
    create_if_missing: bool = False,
    must_be_file: bool = False,
    must_be_dir: bool = False,
) -> Path:
    """
    Validate that a path exists and optionally check its type.

    Args:
        path: Path to validate.
        name: Name of the path for error messages.
        create_if_missing: If True and path doesn't exist, create it as a directory
            (including missing parents).
        must_be_file: If True, ensure path points to a file.
        must_be_dir: If True, ensure path points to a directory.

    Returns:
        The validated path as an absolute ``Path`` with ``~`` expanded and
        symlinks resolved (``Path(path).expanduser().resolve()``).

    Raises:
        ValidationError: If path validation fails.
        AssertionError: If ``create_if_missing`` and ``must_be_file`` are both
            True, since a created path is always a directory.

    Examples:
        >>> config_dir = validate_path_exists("./config", name="config_dir", must_be_dir=True)
        >>> config_dir.is_absolute()
        True
    """
    if path is None:
        raise ValidationError(f"{name} cannot be None", field=name, value=path)

    # Raised explicitly rather than via `assert` so the check still runs when
    # Python is started with -O; the exception type is unchanged.
    if create_if_missing and must_be_file:
        raise AssertionError(
            "create_if_missing and must_be_file cannot both be True. By definition, a created path is a directory."
        )

    path_obj = Path(path).expanduser().resolve()

    if not path_obj.exists():
        if create_if_missing:
            log.info(f"Creating missing directory: {path_obj}")
            path_obj.mkdir(parents=True, exist_ok=True)
        else:
            raise ValidationError(
                f"{name} does not exist: {path_obj}", field=name, value=str(path)
            )

    if must_be_file and not path_obj.is_file():
        raise ValidationError(
            f"{name} must be a file, but got directory: {path_obj}",
            field=name,
            value=str(path),
        )

    if must_be_dir and not path_obj.is_dir():
        raise ValidationError(
            f"{name} must be a directory, but got file: {path_obj}",
            field=name,
            value=str(path),
        )

    return path_obj
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def validate_file_exists(path: Union[str, Path], name: str = "file") -> Path:
    """
    Check that ``path`` exists and points to a file.

    Convenience wrapper around ``validate_path_exists`` with
    ``must_be_file=True``.

    Args:
        path: File path to validate.
        name: Label used for the file in error messages.

    Returns:
        Path object of the validated file.

    Raises:
        ValidationError: If the path doesn't exist or is not a file.
    """
    validated = validate_path_exists(path, name, must_be_file=True)
    return validated
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def validate_directory_exists(
    path: Union[str, Path], name: str = "directory", create_if_missing: bool = False
) -> Path:
    """
    Check that ``path`` exists and points to a directory.

    Convenience wrapper around ``validate_path_exists`` with
    ``must_be_dir=True``.

    Args:
        path: Directory path to validate.
        name: Label used for the directory in error messages.
        create_if_missing: If True, create the directory when it doesn't exist.

    Returns:
        Path object of the validated directory.

    Raises:
        ValidationError: If the directory doesn't exist (and is not being
            created) or the path is not a directory.
    """
    validated = validate_path_exists(
        path, name, create_if_missing=create_if_missing, must_be_dir=True
    )
    return validated
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def validate_model_name(
    model_name: str, allow_special: bool = True, field: str = "model_name"
) -> str:
    """
    Validate a model name string.

    Args:
        model_name: Model name to validate.
        allow_special: If True, allow special names like "_pretrained_". If False,
            names starting and ending with underscores will be rejected.
        field: Field name for error messages.

    Returns:
        The validated model name with leading and trailing whitespace removed.

    Raises:
        ValidationError: If model name is invalid.

    Examples:
        >>> validate_model_name("openai/clip-vit-base-patch32")
        'openai/clip-vit-base-patch32'
        >>> validate_model_name("_pretrained_", allow_special=True)
        '_pretrained_'
        >>> validate_model_name("_pretrained_", allow_special=False)
        Traceback (most recent call last):
        ...
        ValidationError: Validation error for 'model_name': Special model names (starting and ending with '_') are not allowed (got: '_pretrained_')
    """
    # Check the type first so that objects with ambiguous truthiness are
    # rejected with a ValidationError instead of leaking their own exception.
    if not isinstance(model_name, str) or not model_name:
        raise ValidationError(
            "Model name must be a non-empty string", field=field, value=model_name
        )

    stripped = model_name.strip()
    if not stripped:
        # Report what the caller actually passed, not the stripped '' result.
        raise ValidationError(
            "Model name cannot be empty or whitespace only",
            field=field,
            value=model_name,
        )

    # Check for special names (e.g., _pretrained_, _base_model_)
    if not allow_special and stripped.startswith("_") and stripped.endswith("_"):
        raise ValidationError(
            "Special model names (starting and ending with '_') are not allowed",
            field=field,
            value=stripped,
        )

    # Control characters inside the name are rejected; the first match in
    # this fixed order is the one reported.
    invalid_chars = ("\n", "\r", "\t", "\0")
    offending = next((char for char in invalid_chars if char in stripped), None)
    if offending is not None:
        raise ValidationError(
            f"Model name contains invalid character: {offending!r}",
            field=field,
            value=stripped,
        )

    return stripped
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: fusion-bench
|
|
3
|
-
Version: 0.2.27
|
|
3
|
+
Version: 0.2.29
|
|
4
4
|
Summary: A Comprehensive Benchmark of Deep Model Fusion
|
|
5
5
|
Author-email: Anke Tang <tang.anke@foxmail.com>
|
|
6
6
|
Project-URL: Repository, https://github.com/tanganke/fusion_bench
|
|
@@ -44,12 +44,14 @@ Dynamic: license-file
|
|
|
44
44
|
|
|
45
45
|
[](http://arxiv.org/abs/2406.03280)
|
|
46
46
|
[](https://github.com/tanganke/fusion_bench/blob/main/LICENSE)
|
|
47
|
-
[](https://pypi.org/project/fusion-bench/)
|
|
48
|
-
[](https://pepy.tech/project/fusion-bench)
|
|
49
47
|
[](https://tanganke.github.io/fusion_bench/)
|
|
50
48
|
[](https://github.com/psf/black)
|
|
51
49
|
[](https://github.com/google/yamlfmt)
|
|
52
50
|
|
|
51
|
+
[](https://www.codefactor.io/repository/github/tanganke/fusion_bench/overview/main)
|
|
52
|
+
[](https://pypi.org/project/fusion-bench/)
|
|
53
|
+
[](https://pepy.tech/project/fusion-bench)
|
|
54
|
+
|
|
53
55
|
</div>
|
|
54
56
|
|
|
55
57
|
> [!TIP]
|
|
@@ -205,6 +207,48 @@ The CLI's design allows for easy extension to new fusion methods, model types, a
|
|
|
205
207
|
|
|
206
208
|
Read the [CLI documentation](https://tanganke.github.io/fusion_bench/cli/fusion_bench/) for more information.
|
|
207
209
|
|
|
210
|
+
## The FusionBench Workflow
|
|
211
|
+
|
|
212
|
+
FusionBench follows a three-component architecture (ModelPool, Method, and TaskPool) to perform model fusion experiments:
|
|
213
|
+
|
|
214
|
+
```mermaid
|
|
215
|
+
graph LR
|
|
216
|
+
CLI[fusion_bench CLI] --> Hydra[Hydra Config]
|
|
217
|
+
Hydra --> Program[Program]
|
|
218
|
+
|
|
219
|
+
Program --> MP[ModelPool<br/>Manages Models<br/>& Datasets]
|
|
220
|
+
Program --> Method[Method<br/>Fusion Algorithm]
|
|
221
|
+
Program --> TP[TaskPool<br/>Evaluation Tasks]
|
|
222
|
+
|
|
223
|
+
MP --> Method
|
|
224
|
+
Method --> Merged[Merged Model]
|
|
225
|
+
Merged --> TP
|
|
226
|
+
TP --> Report[Evaluation Report]
|
|
227
|
+
|
|
228
|
+
style CLI fill:#e1f5e1
|
|
229
|
+
style Hydra fill:#f0e1ff
|
|
230
|
+
style Method fill:#ffe1f0
|
|
231
|
+
style Merged fill:#fff4e1
|
|
232
|
+
style Report fill:#e1f0ff
|
|
233
|
+
```
|
|
234
|
+
|
|
235
|
+
**Key Components:**
|
|
236
|
+
|
|
237
|
+
1. **CLI**: Entry point using Hydra for configuration management
|
|
238
|
+
2. **Program**: Orchestrates the fusion workflow (e.g., `FabricModelFusionProgram`)
|
|
239
|
+
3. **ModelPool**: Manages task-specific models and their datasets
|
|
240
|
+
4. **Method**: Implements the fusion algorithm (e.g., Simple Average, Task Arithmetic, AdaMerging)
|
|
241
|
+
5. **TaskPool**: Evaluates the merged model on benchmark tasks
|
|
242
|
+
|
|
243
|
+
**Workflow Steps:**
|
|
244
|
+
|
|
245
|
+
1. User runs `fusion_bench` with config overrides
|
|
246
|
+
2. Hydra loads YAML configs for method, modelpool, and taskpool
|
|
247
|
+
3. Program instantiates all three components
|
|
248
|
+
4. Method executes fusion algorithm on ModelPool
|
|
249
|
+
5. TaskPool evaluates the merged model
|
|
250
|
+
6. Results are saved and reported
|
|
251
|
+
|
|
208
252
|
## Implement your own model fusion algorithm
|
|
209
253
|
|
|
210
254
|
First, create a new Python file for the algorithm in the `fusion_bench/method` directory.
|
|
@@ -272,11 +316,24 @@ Click on [<kbd>Use this template</kbd>](https://github.com/fusion-bench/fusion-b
|
|
|
272
316
|
|
|
273
317
|
</div>
|
|
274
318
|
|
|
275
|
-
### FusionBench Command Generator WebUI
|
|
319
|
+
### FusionBench Command Generator WebUI
|
|
320
|
+
|
|
321
|
+
> [!NOTE]
|
|
322
|
+
> Requires the `gradio` package. Install it with `pip install gradio`.
|
|
323
|
+
|
|
324
|
+
For users who prefer a graphical interface, FusionBench provides an interactive web UI for generating commands:
|
|
325
|
+
|
|
326
|
+
```bash
|
|
327
|
+
fusion_bench_webui
|
|
328
|
+
```
|
|
329
|
+
|
|
330
|
+
This launches a browser-based interface where you can:
|
|
276
331
|
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
332
|
+
- Select root configurations and components through dropdowns
|
|
333
|
+
- Adjust hyperparameters interactively
|
|
334
|
+
- View real-time YAML configuration updates
|
|
335
|
+
|
|
336
|
+
The WebUI is particularly useful for exploring available configurations, experimenting with different parameter combinations, and learning the FusionBench configuration structure. [Learn more about the WebUI](https://tanganke.github.io/fusion_bench/cli/fusion_bench_webui/).
|
|
280
337
|
|
|
281
338
|

|
|
282
339
|
|
|
@@ -296,3 +353,5 @@ If you find this benchmark useful, please consider citing our work:
|
|
|
296
353
|
## Star History
|
|
297
354
|
|
|
298
355
|
[](https://www.star-history.com/#tanganke/fusion_bench&Date)
|
|
356
|
+
|
|
357
|
+

|