fusion-bench 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/constants/__init__.py +5 -1
  3. fusion_bench/constants/runtime.py +111 -7
  4. fusion_bench/dataset/gsm8k.py +6 -2
  5. fusion_bench/dataset/image_corruption/make_corruption.py +168 -0
  6. fusion_bench/method/__init__.py +1 -1
  7. fusion_bench/method/classification/image_classification_finetune.py +13 -2
  8. fusion_bench/method/gossip/clip_task_wise_gossip.py +1 -29
  9. fusion_bench/method/task_arithmetic/task_arithmetic.py +4 -1
  10. fusion_bench/metrics/nyuv2/__init__.py +31 -0
  11. fusion_bench/metrics/nyuv2/depth.py +30 -0
  12. fusion_bench/metrics/nyuv2/loss.py +40 -0
  13. fusion_bench/metrics/nyuv2/noise.py +24 -0
  14. fusion_bench/metrics/nyuv2/normal.py +34 -1
  15. fusion_bench/metrics/nyuv2/segmentation.py +35 -1
  16. fusion_bench/mixins/clip_classification.py +30 -2
  17. fusion_bench/mixins/lightning_fabric.py +46 -5
  18. fusion_bench/mixins/rich_live.py +76 -0
  19. fusion_bench/modelpool/__init__.py +24 -2
  20. fusion_bench/modelpool/base_pool.py +94 -6
  21. fusion_bench/modelpool/convnext_for_image_classification.py +198 -0
  22. fusion_bench/modelpool/dinov2_for_image_classification.py +197 -0
  23. fusion_bench/modelpool/resnet_for_image_classification.py +4 -1
  24. fusion_bench/models/model_card_templates/default.md +1 -1
  25. fusion_bench/scripts/webui.py +250 -17
  26. fusion_bench/utils/__init__.py +14 -0
  27. fusion_bench/utils/data.py +100 -9
  28. fusion_bench/utils/fabric.py +185 -4
  29. fusion_bench/utils/json.py +55 -8
  30. fusion_bench/utils/validation.py +197 -0
  31. {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/METADATA +66 -7
  32. {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/RECORD +44 -40
  33. fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +6 -19
  34. fusion_bench_config/llama_full_finetune.yaml +4 -16
  35. fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  36. fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml +10 -0
  37. fusion_bench_config/modelpool/Dinov2ForImageClassification/dinov2-base-imagenet1k-1-layer.yaml +10 -0
  38. fusion_bench_config/nyuv2_config.yaml +4 -13
  39. fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  40. fusion_bench_config/taskpool/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
  41. fusion_bench/utils/auto.py +0 -31
  42. {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/WHEEL +0 -0
  43. {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/entry_points.txt +0 -0
  44. {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/licenses/LICENSE +0 -0
  45. {fusion_bench-0.2.27.dist-info → fusion_bench-0.2.29.dist-info}/top_level.txt +0 -0
@@ -1,24 +1,205 @@
1
1
  import time
2
- from typing import Optional
2
+ from contextlib import contextmanager
3
+ from functools import wraps
4
+ from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
3
5
 
4
6
  import lightning as L
7
+ import torch
5
8
 
6
9
  from fusion_bench.utils.pylogger import get_rankzero_logger
7
10
 
8
11
  log = get_rankzero_logger(__name__)
9
12
 
13
+ T = TypeVar("T")
10
14
 
11
def seed_everything_by_time(fabric: Optional[L.Fabric] = None) -> int:
    """
    Seed all random number generators with a seed derived from the current time.

    The seed is ``int(time.time())`` (whole seconds), generated on the global
    zero process. When a Fabric instance is given, it is broadcast to all other
    processes so every worker uses the same seed. Without a Fabric instance the
    seed is generated locally without any synchronization.

    Note that the seed is *not* reproducible across runs: runs started at
    different times get different seeds (runs started within the same second
    get the same one). Log or persist the returned value if you need to
    reproduce a run later.

    Args:
        fabric: Optional Lightning Fabric instance for distributed synchronization.
            If None, the seed is generated locally without broadcasting.

    Returns:
        The seed value passed to ``L.seed_everything``.

    Example:
        ```python
        import lightning as L
        from fusion_bench.utils.fabric import seed_everything_by_time

        # With fabric (distributed)
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()
        seed = seed_everything_by_time(fabric)
        print(f"All processes using seed: {seed}")

        # Without fabric (single process)
        seed = seed_everything_by_time()
        print(f"Using seed: {seed}")
        ```

    Note:
        - In distributed settings, only the global zero process generates the seed.
        - All other processes receive the broadcast seed, so they agree with rank 0.
        - The seed is based on `time.time()`, so it will differ across runs.
    """
    # Only the global zero process (or the sole process when no fabric is given)
    # generates the seed; the other ranks receive it from the broadcast below.
    if fabric is None or fabric.is_global_zero:
        seed = int(time.time())
        log.info(f"Generated time-based seed: {seed}")
    else:
        seed = None

    # Broadcast the seed to all processes in a distributed setting.
    if fabric is not None:
        log.debug(f"Broadcasting seed `{seed}` to all processes")
        # Make sure every rank reaches this point before the collective call.
        fabric.barrier()
        seed = fabric.broadcast(seed, src=0)

    # Apply the (now shared) seed to all random number generators.
    L.seed_everything(seed)
    return seed
69
+
70
+
71
def is_distributed(fabric: Optional[L.Fabric] = None) -> bool:
    """
    Tell whether the current run spans more than one process.

    Args:
        fabric: Optional Lightning Fabric instance. When None, the run is
            treated as single-process and False is returned.

    Returns:
        True if ``fabric`` is given and its ``world_size`` is greater than 1,
        False otherwise.

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()
        if is_distributed(fabric):
            print("Running in distributed mode")
        ```
    """
    if fabric is None:
        return False
    return fabric.world_size > 1
90
+
91
+
92
def get_world_info(fabric: Optional[L.Fabric] = None) -> Dict[str, Any]:
    """
    Summarize the process layout of the current run.

    Args:
        fabric: Optional Lightning Fabric instance. When None, values describing
            a single, non-distributed process are returned.

    Returns:
        A dictionary with the keys:
            - world_size: Total number of processes
            - global_rank: Global rank of current process
            - local_rank: Local rank on current node
            - is_global_zero: Whether this is the main process
            - is_distributed: Whether more than one process is running

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()
        info = get_world_info(fabric)
        print(f"Process {info['global_rank']}/{info['world_size']}")
        ```
    """
    if fabric is not None:
        world_size = fabric.world_size
        return {
            "world_size": world_size,
            "global_rank": fabric.global_rank,
            "local_rank": fabric.local_rank,
            "is_global_zero": fabric.is_global_zero,
            "is_distributed": world_size > 1,
        }

    # Without a Fabric instance this process is the only one.
    return {
        "world_size": 1,
        "global_rank": 0,
        "local_rank": 0,
        "is_global_zero": True,
        "is_distributed": False,
    }
131
+
132
+
133
def wait_for_everyone(
    fabric: Optional[L.Fabric] = None, message: Optional[str] = None
) -> None:
    """
    Block until every process reaches this point, optionally logging first.

    Thin wrapper around ``fabric.barrier()``. The message, if any, is logged
    only on the global zero process.

    Args:
        fabric: Optional Lightning Fabric instance. If None, nothing happens.
        message: Optional message to log before synchronizing.

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()

        # Do some work...
        wait_for_everyone(fabric, "Waiting after model loading")
        # All processes synchronized
        ```
    """
    if fabric is None:
        # No Fabric instance: there is nothing to synchronize with.
        return
    if message and fabric.is_global_zero:
        log.info(message)
    fabric.barrier()
159
+
160
+
161
@contextmanager
def rank_zero_only_context(fabric: Optional[L.Fabric] = None):
    """
    Context manager that yields whether the current process is global rank 0.

    A context manager cannot skip the body of its ``with`` block, so the body
    runs on *every* process. Instead, this manager yields a boolean flag:
    True on global rank 0 (or when no Fabric instance is given), False on all
    other ranks. Callers must check that flag to restrict work to rank 0.

    Args:
        fabric: Optional Lightning Fabric instance. If None, the current
            process is treated as rank 0.

    Yields:
        bool: True if rank-0-only code should execute on this process.

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()

        with rank_zero_only_context(fabric) as is_rank_zero:
            if is_rank_zero:
                print("This prints only on rank 0")
                save_checkpoint(model, "checkpoint.pt")
        ```
    """
    # Nothing needs to be set up or torn down; the flag is all this provides.
    yield fabric is None or fabric.is_global_zero
184
+
185
+
186
def print_on_rank_zero(*args, fabric: Optional[L.Fabric] = None, **kwargs) -> None:
    """
    Call the built-in ``print`` only on the global zero process.

    Args:
        *args: Positional arguments forwarded to ``print()``.
        fabric: Optional Lightning Fabric instance. If None, always prints.
        **kwargs: Keyword arguments forwarded to ``print()``.

    Example:
        ```python
        fabric = L.Fabric(accelerator="auto", devices=2)
        fabric.launch()

        print_on_rank_zero("Starting training", fabric=fabric)
        # Prints only on rank 0
        ```
    """
    is_rank_zero = fabric is None or fabric.is_global_zero
    if not is_rank_zero:
        return
    print(*args, **kwargs)
@@ -1,31 +1,78 @@
1
1
  import json
2
2
  from pathlib import Path
3
- from typing import Any, Union
3
+ from typing import TYPE_CHECKING, Any, Union
4
4
 
5
+ from fusion_bench.utils.validation import validate_file_exists
5
6
 
6
- def save_to_json(obj, path: Union[str, Path]):
7
+ if TYPE_CHECKING:
8
+ from pyarrow.fs import FileSystem
9
+
10
+
11
def save_to_json(obj, path: Union[str, Path], filesystem: "FileSystem" = None):
    """
    Serialize an object as JSON and write it to a file.

    Args:
        obj (Any): the object to save; must be JSON-serializable.
        path (Union[str, Path]): the destination path.
        filesystem (FileSystem, optional): filesystem to write through.
            - Objects exposing an ``open`` method (fsspec filesystems such as
              ``s3fs.S3FileSystem``) are written to in text mode.
            - Anything else is treated as a PyArrow FileSystem and written via
              ``open_output_stream``.
            - If None, the local filesystem is used via the built-in ``open()``.
    """
    if filesystem is None:
        # Local filesystem: stream directly into the file.
        with open(path, "w") as f:
            json.dump(obj, f)
        return

    payload = json.dumps(obj)
    target = str(path)
    if hasattr(filesystem, "open"):
        # fsspec-style API (e.g. s3fs), opened in text mode.
        with filesystem.open(target, "w") as f:
            f.write(payload)
    else:
        # PyArrow output streams accept bytes only.
        with filesystem.open_output_stream(target) as f:
            f.write(payload.encode("utf-8"))
16
39
 
17
40
 
18
def load_from_json(
    path: Union[str, Path], filesystem: "FileSystem" = None
) -> Union[dict, list]:
    """Load an object from a JSON file.

    Args:
        path (Union[str, Path]): the path to load the object from.
        filesystem (FileSystem, optional): filesystem to read through.
            - Objects exposing an ``open`` method (fsspec filesystems such as
              ``s3fs.S3FileSystem``) are read in text mode.
            - Anything else is treated as a PyArrow FileSystem and read via
              ``open_input_stream``.
            - If None, the local filesystem is used via the built-in ``open()``.

    Returns:
        Union[dict, list]: the loaded object

    Raises:
        ValidationError: If the file doesn't exist (when using local filesystem)
    """
    if filesystem is None:
        validate_file_exists(path)
        # JSON text is UTF-8 (RFC 8259); decode explicitly rather than relying on
        # the platform's locale encoding (e.g. cp1252 on Windows).
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    path_str = str(path)
    if hasattr(filesystem, "open"):
        # fsspec-style API (e.g. s3fs), opened in text mode.
        with filesystem.open(path_str, "r") as f:
            return json.load(f)

    # PyArrow input streams yield bytes; decode them as UTF-8.
    with filesystem.open_input_stream(path_str) as f:
        return json.loads(f.read().decode("utf-8"))
29
76
 
30
77
 
31
78
  def _is_list_of_dict(obj) -> bool:
@@ -0,0 +1,197 @@
1
+ """
2
+ Validation utilities for FusionBench.
3
+
4
+ This module provides robust input validation functions to ensure data integrity
5
+ and provide clear error messages throughout the FusionBench framework.
6
+ """
7
+
8
+ import logging
9
+ from pathlib import Path
10
+ from typing import Any, Optional, Union
11
+
12
+ __all__ = [
13
+ "ValidationError",
14
+ "validate_path_exists",
15
+ "validate_file_exists",
16
+ "validate_directory_exists",
17
+ "validate_model_name",
18
+ ]
19
+
20
+ log = logging.getLogger(__name__)
21
+
22
+
23
class ValidationError(ValueError):
    """Error raised when an input fails validation.

    When ``field`` is given, the message is prefixed with the field name, and
    ``value`` (if not None) is appended in ``repr`` form. Without a field name,
    the message is used verbatim.

    Attributes:
        field: Name of the offending field, or None.
        value: The offending value, or None.
    """

    def __init__(self, message: str, field: Optional[str] = None, value: Any = None):
        self.field = field
        self.value = value
        if not field:
            super().__init__(message)
            return
        text = f"Validation error for '{field}': {message}"
        if value is not None:
            text = f"{text} (got: {value!r})"
        super().__init__(text)
35
+
36
+
37
def validate_path_exists(
    path: Union[str, Path],
    name: str = "path",
    create_if_missing: bool = False,
    must_be_file: bool = False,
    must_be_dir: bool = False,
) -> Path:
    """
    Validate that a path exists and optionally check its type.

    Args:
        path: Path to validate.
        name: Name of the path for error messages.
        create_if_missing: If True and path doesn't exist, create it as a directory.
        must_be_file: If True, ensure path points to a file.
        must_be_dir: If True, ensure path points to a directory.

    Returns:
        The validated path, with ``~`` expanded and resolved to an absolute path.

    Raises:
        ValidationError: If path is None, does not exist (and is not created),
            or is not of the required type.
        ValueError: If the flags are contradictory (``create_if_missing`` with
            ``must_be_file``, or ``must_be_file`` with ``must_be_dir``).

    Examples:
        >>> validate_path_exists("./config", name="config_dir", must_be_dir=True)  # doctest: +SKIP
        PosixPath('/absolute/path/to/config')
    """
    if path is None:
        raise ValidationError(f"{name} cannot be None", field=name, value=path)

    # Contradictory flags are a programming error. Raise explicitly instead of
    # using `assert`, which is stripped when Python runs with -O.
    if create_if_missing and must_be_file:
        raise ValueError(
            "create_if_missing and must_be_file cannot both be True. "
            "By definition, a created path is a directory."
        )
    if must_be_file and must_be_dir:
        raise ValueError("must_be_file and must_be_dir cannot both be True.")

    path_obj = Path(path).expanduser().resolve()

    if not path_obj.exists():
        if not create_if_missing:
            raise ValidationError(
                f"{name} does not exist: {path_obj}", field=name, value=str(path)
            )
        log.info(f"Creating missing directory: {path_obj}")
        path_obj.mkdir(parents=True, exist_ok=True)

    if must_be_file and not path_obj.is_file():
        # An existing non-file is usually a directory, but may be a special
        # file (socket, FIFO, device); report what was actually found.
        found = "directory" if path_obj.is_dir() else "special file"
        raise ValidationError(
            f"{name} must be a file, but got {found}: {path_obj}",
            field=name,
            value=str(path),
        )

    if must_be_dir and not path_obj.is_dir():
        found = "file" if path_obj.is_file() else "special file"
        raise ValidationError(
            f"{name} must be a directory, but got {found}: {path_obj}",
            field=name,
            value=str(path),
        )

    return path_obj
97
+
98
+
99
def validate_file_exists(path: Union[str, Path], name: str = "file") -> Path:
    """
    Check that a path exists and is a file.

    Delegates to ``validate_path_exists`` with ``must_be_file=True``.

    Args:
        path: File path to check.
        name: Label used for the path in error messages.

    Returns:
        The validated file path, as returned by ``validate_path_exists``.

    Raises:
        ValidationError: If the file doesn't exist or is not a file.
    """
    validated = validate_path_exists(path, name=name, must_be_file=True)
    return validated
114
+
115
+
116
def validate_directory_exists(
    path: Union[str, Path], name: str = "directory", create_if_missing: bool = False
) -> Path:
    """
    Check that a path exists and is a directory, optionally creating it.

    Delegates to ``validate_path_exists`` with ``must_be_dir=True``.

    Args:
        path: Directory path to check.
        name: Label used for the path in error messages.
        create_if_missing: If True, create the directory when it does not exist.

    Returns:
        The validated directory path, as returned by ``validate_path_exists``.

    Raises:
        ValidationError: If the directory doesn't exist (and is not created) or
            is not a directory.
    """
    validated = validate_path_exists(
        path,
        name=name,
        must_be_dir=True,
        create_if_missing=create_if_missing,
    )
    return validated
136
+
137
+
138
def validate_model_name(
    model_name: str, allow_special: bool = True, field: str = "model_name"
) -> str:
    """
    Validate a model name string.

    Args:
        model_name: Model name to validate.
        allow_special: If True, allow special names like "_pretrained_". If False,
            names starting and ending with underscores will be rejected.
        field: Field name for error messages.

    Returns:
        The validated model name, with surrounding whitespace stripped.

    Raises:
        ValidationError: If model name is invalid.

    Examples:
        >>> validate_model_name("openai/clip-vit-base-patch32")
        'openai/clip-vit-base-patch32'
        >>> validate_model_name("_pretrained_", allow_special=True)
        '_pretrained_'
        >>> validate_model_name("_pretrained_", allow_special=False)
        Traceback (most recent call last):
        ...
        ValidationError: Validation error for 'model_name': Special model names (starting and ending with '_') are not allowed (got: '_pretrained_')
    """
    # Check the type before truthiness: evaluating `not model_name` on an
    # arbitrary object (e.g. a NumPy array) can raise or be misleading.
    if not isinstance(model_name, str) or not model_name:
        raise ValidationError(
            "Model name must be a non-empty string", field=field, value=model_name
        )

    stripped = model_name.strip()
    if not stripped:
        # Report the original whitespace-only value; the stripped one is always ''.
        raise ValidationError(
            "Model name cannot be empty or whitespace only",
            field=field,
            value=model_name,
        )
    model_name = stripped

    # Check for special names (e.g., _pretrained_, _base_model_)
    if not allow_special and model_name.startswith("_") and model_name.endswith("_"):
        raise ValidationError(
            "Special model names (starting and ending with '_') are not allowed",
            field=field,
            value=model_name,
        )

    # Reject control characters that might cause issues downstream.
    for char in ("\n", "\r", "\t", "\0"):
        if char in model_name:
            raise ValidationError(
                f"Model name contains invalid character: {char!r}",
                field=field,
                value=model_name,
            )

    return model_name
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fusion-bench
3
- Version: 0.2.27
3
+ Version: 0.2.29
4
4
  Summary: A Comprehensive Benchmark of Deep Model Fusion
5
5
  Author-email: Anke Tang <tang.anke@foxmail.com>
6
6
  Project-URL: Repository, https://github.com/tanganke/fusion_bench
@@ -44,12 +44,14 @@ Dynamic: license-file
44
44
 
45
45
  [![arXiv](https://img.shields.io/badge/arXiv-2406.03280-b31b1b.svg)](http://arxiv.org/abs/2406.03280)
46
46
  [![GitHub License](https://img.shields.io/github/license/tanganke/fusion_bench)](https://github.com/tanganke/fusion_bench/blob/main/LICENSE)
47
- [![PyPI - Version](https://img.shields.io/pypi/v/fusion-bench)](https://pypi.org/project/fusion-bench/)
48
- [![Downloads](https://static.pepy.tech/badge/fusion-bench/month)](https://pepy.tech/project/fusion-bench)
49
47
  [![Static Badge](https://img.shields.io/badge/doc-mkdocs-blue)](https://tanganke.github.io/fusion_bench/)
50
48
  [![Static Badge](https://img.shields.io/badge/code%20style-black-black)](https://github.com/psf/black)
51
49
  [![Static Badge](https://img.shields.io/badge/code%20style-yamlfmt-black)](https://github.com/google/yamlfmt)
52
50
 
51
+ [![CodeFactor](https://www.codefactor.io/repository/github/tanganke/fusion_bench/badge/main)](https://www.codefactor.io/repository/github/tanganke/fusion_bench/overview/main)
52
+ [![PyPI - Version](https://img.shields.io/pypi/v/fusion-bench)](https://pypi.org/project/fusion-bench/)
53
+ [![Downloads](https://static.pepy.tech/badge/fusion-bench/month)](https://pepy.tech/project/fusion-bench)
54
+
53
55
  </div>
54
56
 
55
57
  > [!TIP]
@@ -205,6 +207,48 @@ The CLI's design allows for easy extension to new fusion methods, model types, a
205
207
 
206
208
  Read the [CLI documentation](https://tanganke.github.io/fusion_bench/cli/fusion_bench/) for more information.
207
209
 
210
+ ## The FusionBench Workflow
211
+
212
+ FusionBench is built around three core components — ModelPool, Method, and TaskPool — which a Program orchestrates to run model fusion experiments:
213
+
214
+ ```mermaid
215
+ graph LR
216
+ CLI[fusion_bench CLI] --> Hydra[Hydra Config]
217
+ Hydra --> Program[Program]
218
+
219
+ Program --> MP[ModelPool<br/>Manages Models<br/>& Datasets]
220
+ Program --> Method[Method<br/>Fusion Algorithm]
221
+ Program --> TP[TaskPool<br/>Evaluation Tasks]
222
+
223
+ MP --> Method
224
+ Method --> Merged[Merged Model]
225
+ Merged --> TP
226
+ TP --> Report[Evaluation Report]
227
+
228
+ style CLI fill:#e1f5e1
229
+ style Hydra fill:#f0e1ff
230
+ style Method fill:#ffe1f0
231
+ style Merged fill:#fff4e1
232
+ style Report fill:#e1f0ff
233
+ ```
234
+
235
+ **Key Components:**
236
+
237
+ 1. **CLI**: Entry point using Hydra for configuration management
238
+ 2. **Program**: Orchestrates the fusion workflow (e.g., `FabricModelFusionProgram`)
239
+ 3. **ModelPool**: Manages task-specific models and their datasets
240
+ 4. **Method**: Implements the fusion algorithm (e.g., Simple Average, Task Arithmetic, AdaMerging)
241
+ 5. **TaskPool**: Evaluates the merged model on benchmark tasks
242
+
243
+ **Workflow Steps:**
244
+
245
+ 1. User runs `fusion_bench` with config overrides
246
+ 2. Hydra loads YAML configs for method, modelpool, and taskpool
247
+ 3. Program instantiates all three components
248
+ 4. Method executes fusion algorithm on ModelPool
249
+ 5. TaskPool evaluates the merged model
250
+ 6. Results are saved and reported
251
+
208
252
  ## Implement your own model fusion algorithm
209
253
 
210
254
  First, create a new Python file for the algorithm in the `fusion_bench/method` directory.
@@ -272,11 +316,24 @@ Click on [<kbd>Use this template</kbd>](https://github.com/fusion-bench/fusion-b
272
316
 
273
317
  </div>
274
318
 
275
- ### FusionBench Command Generator WebUI (for v0.1.x)
319
+ ### FusionBench Command Generator WebUI
320
+
321
+ > [!NOTE]
322
+ > Requires `gradio` package. Install with `pip install gradio`.
323
+
324
+ For users who prefer a graphical interface, FusionBench provides an interactive web UI for generating commands:
325
+
326
+ ```bash
327
+ fusion_bench_webui
328
+ ```
329
+
330
+ This launches a browser-based interface where you can:
276
331
 
277
- FusionBench Command Generator is a user-friendly web interface for generating FusionBench commands based on configuration files.
278
- It provides an interactive way to select and customize FusionBench configurations, making it easier to run experiments with different settings.
279
- [Read more here](https://tanganke.github.io/fusion_bench/cli/fusion_bench_webui/).
332
+ - Select root configurations and components through dropdowns
333
+ - Adjust hyperparameters interactively
334
+ - View real-time YAML configuration updates
335
+
336
+ The WebUI is particularly useful for exploring available configurations, experimenting with different parameter combinations, and learning the FusionBench configuration structure. [Learn more about the WebUI](https://tanganke.github.io/fusion_bench/cli/fusion_bench_webui/).
280
337
 
281
338
  ![FusionBench Command Generator Web Interface](docs/cli/images/fusion_bench_webui.png)
282
339
 
@@ -296,3 +353,5 @@ If you find this benchmark useful, please consider citing our work:
296
353
  ## Star History
297
354
 
298
355
  [![Star History Chart](https://api.star-history.com/svg?repos=tanganke/fusion_bench&type=Date)](https://www.star-history.com/#tanganke/fusion_bench&Date)
356
+
357
+ ![Repobeats analytics image](https://repobeats.axiom.co/api/embed/83f1f046562e4a4787bdd6ed1190856f9f30bd9f.svg "Repobeats analytics image")