fusion-bench 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/clip_classification.py +26 -6
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/serialization.py +40 -83
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +285 -44
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +10 -4
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +12 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +1 -0
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +36 -11
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/lazy_state_dict.py +85 -19
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/rich_utils.py +7 -3
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +77 -64
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
fusion_bench/utils/parameters.py
CHANGED

@@ -129,7 +129,6 @@ def human_readable(num: int) -> str:
     Converts a number into a human-readable string with appropriate magnitude suffix.
 
     Examples:
-
         ```python
         print(human_readable(1500))
         # Output: '1.50K'

@@ -201,7 +200,6 @@ def count_parameters(module: nn.Module, non_zero_only: bool = False) -> tuple[int, int]:
     tuple: A tuple containing the number of trainable parameters and the total number of parameters.
 
     Examples:
-
        ```python
        # Count the parameters
        trainable_params, all_params = count_parameters(model)
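For reference, a minimal usage sketch of the two helpers touched above, based only on the signatures and docstring snippets shown in this diff (the toy `nn.Linear` model and printed values are illustrative):

```python
import torch.nn as nn

from fusion_bench.utils.parameters import count_parameters, human_readable

model = nn.Linear(128, 64)  # hypothetical toy model with 8,256 parameters

# Per the docstring above, count_parameters returns (trainable, total).
trainable_params, all_params = count_parameters(model)

print(human_readable(all_params))  # expected: '8.26K'
print(human_readable(1500))        # '1.50K', per the docstring example
```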
fusion_bench/utils/rich_utils.py
CHANGED

@@ -188,17 +188,21 @@ if __name__ == "__main__":
     display_available_styles()
 
 
-def setup_colorlogging(
+def setup_colorlogging(
+    force=False,
+    level=logging.INFO,
+    **kwargs,
+):
     """
     Sets up color logging for the application.
     """
     FORMAT = "%(message)s"
 
     logging.basicConfig(
-        level=
+        level=level,
         format=FORMAT,
        datefmt="[%X]",
        handlers=[RichHandler()],
        force=force,
-        **
+        **kwargs,
    )
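A short sketch of how the updated `setup_colorlogging` signature might be called; the `level` argument added in this diff is forwarded to `logging.basicConfig` (the logger usage below is illustrative):

```python
import logging

from fusion_bench.utils.rich_utils import setup_colorlogging

# `level` is new in this version and is passed through to logging.basicConfig,
# alongside the existing `force` flag and any extra **kwargs.
setup_colorlogging(force=True, level=logging.DEBUG)

logging.getLogger(__name__).debug("rich-formatted debug output")
```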
fusion_bench/utils/timer.py
CHANGED

@@ -6,38 +6,120 @@ log = logging.getLogger(__name__)
 
 class timeit_context:
     """
-
+    A context manager for measuring and logging execution time of code blocks.
 
-
-    with
-
-
+    This context manager provides precise timing measurements with automatic logging
+    of elapsed time. It supports nested timing contexts with proper indentation
+    for hierarchical timing analysis, making it ideal for profiling complex
+    operations with multiple sub-components.
+
+    Args:
+        msg (str, optional): Custom message to identify the timed code block.
+            If provided, logs "[BEGIN] {msg}" at start and includes context
+            in the final timing report. Defaults to None.
+        loglevel (int, optional): Python logging level for output messages.
+            Uses standard logging levels (DEBUG=10, INFO=20, WARNING=30, etc.).
+            Defaults to logging.INFO.
+
+    Example:
+        Basic usage:
+        ```python
+        with timeit_context("data loading"):
+            data = load_large_dataset()
+        # Logs: [BEGIN] data loading
+        # Logs: [END] Elapsed time: 2.34s
+        ```
+
+        Nested timing:
+        ```python
+        with timeit_context("model training"):
+            with timeit_context("data preprocessing"):
+                preprocess_data()
+            with timeit_context("forward pass"):
+                model(data)
+        # Output shows nested structure:
+        # [BEGIN] model training
+        # [BEGIN] data preprocessing
+        # [END] Elapsed time: 0.15s
+        # [BEGIN] forward pass
+        # [END] Elapsed time: 0.89s
+        # [END] Elapsed time: 1.04s
+        ```
+
+        Custom log level:
+        ```python
+        with timeit_context("debug operation", loglevel=logging.DEBUG):
+            debug_function()
+        ```
     """
 
     nest_level = -1
 
     def _log(self, msg):
+        """
+        Internal method for logging messages with appropriate stack level.
+
+        This helper method ensures that log messages appear to originate from
+        the caller's code rather than from internal timer methods, providing
+        more useful debugging information.
+
+        Args:
+            msg (str): The message to log at the configured log level.
+        """
         log.log(self.loglevel, msg, stacklevel=3)
 
     def __init__(self, msg: str = None, loglevel=logging.INFO) -> None:
+        """
+        Initialize a new timing context with optional message and log level.
+
+        Args:
+            msg (str, optional): Descriptive message for the timed operation.
+                If provided, will be included in the begin/end log messages
+                to help identify what is being timed. Defaults to None.
+            loglevel (int, optional): Python logging level for timer output.
+                Common values include:
+                - logging.DEBUG (10): Detailed debugging information
+                - logging.INFO (20): General information (default)
+                - logging.WARNING (30): Warning messages
+                - logging.ERROR (40): Error messages
+                Defaults to logging.INFO.
+        """
         self.loglevel = loglevel
         self.msg = msg
 
     def __enter__(self) -> None:
         """
-
+        Enter the timing context and start the timer.
 
-
-
+        This method is automatically called when entering the 'with' statement.
+        It records the current timestamp, increments the nesting level for
+        proper log indentation, and optionally logs a begin message.
+
+        Returns:
+            None: This context manager doesn't return a value to the 'as' clause.
+                All timing information is handled internally and logged automatically.
         """
         self.start_time = time.time()
         timeit_context.nest_level += 1
         if self.msg is not None:
             self._log(" " * timeit_context.nest_level + "[BEGIN] " + str(self.msg))
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         """
-
+        Exit the timing context and log the elapsed time.
+
+        This method is automatically called when exiting the 'with' statement,
+        whether through normal completion or exception. It calculates the total
+        elapsed time and logs the results with proper nesting indentation.
+
+        Args:
+            exc_type (type): Exception type if an exception occurred, None otherwise.
+            exc_val (Exception): Exception instance if an exception occurred, None otherwise.
+            exc_tb (traceback): Exception traceback if an exception occurred, None otherwise.
+
+        Returns:
+            None: Does not suppress exceptions (returns None/False implicitly).
+                Any exceptions that occurred in the timed block will propagate normally.
         """
         end_time = time.time()
         elapsed_time = end_time - self.start_time
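A runnable variant of the docstring examples above, assuming only the behavior visible in this diff (the workload is a placeholder):

```python
import logging

from fusion_bench.utils.timer import timeit_context

logging.basicConfig(level=logging.INFO)

# Nested blocks are indented via the class-level `nest_level` counter,
# which __enter__ increments before logging "[BEGIN] ...".
with timeit_context("outer step"):
    with timeit_context("inner step", loglevel=logging.INFO):
        sum(range(1_000_000))  # placeholder workload
```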
{fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fusion_bench
-Version: 0.2.22
+Version: 0.2.24
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 Project-URL: Repository, https://github.com/tanganke/fusion_bench

@@ -23,12 +23,19 @@ Requires-Dist: rich
 Requires-Dist: scipy
 Requires-Dist: h5py
 Requires-Dist: pytest
+Requires-Dist: joblib
+Requires-Dist: bidict
 Requires-Dist: transformers!=4.49
 Requires-Dist: pillow!=11.2.1
 Provides-Extra: lm-eval-harness
 Requires-Dist: lm-eval; extra == "lm-eval-harness"
 Requires-Dist: immutabledict; extra == "lm-eval-harness"
 Requires-Dist: langdetect; extra == "lm-eval-harness"
+Requires-Dist: rich-run; extra == "lm-eval-harness"
+Provides-Extra: docs
+Requires-Dist: mkdocs; extra == "docs"
+Requires-Dist: mkdocs-material; extra == "docs"
+Requires-Dist: mkdocstrings[python]; extra == "docs"
 Dynamic: license-file
 
 <div align='center'>

@@ -151,7 +158,7 @@ This will install the latest version of fusion-bench and the dependencies requir
 Documentation for using LM-Eval Harness within FusionBench framework can be found at [this online documentation](https://tanganke.github.io/fusion_bench/taskpool/lm_eval_harness) or in the [`docs/taskpool/lm_eval_harness.md`](docs/taskpool/lm_eval_harness.md) markdown file.
 
 > [!TIP]
-> Documentation for merging large language models using FusionBench can be found at [this online documentation](https://tanganke.github.io/fusion_bench/modelpool/
+> Documentation for merging large language models using FusionBench can be found at [this online documentation](https://tanganke.github.io/fusion_bench/modelpool/llm) or in the [`docs/modelpool/llm/index.md`](docs/modelpool/llm/index.md) markdown file.
 
 ## Introduction to Deep Model Fusion
 

@@ -179,7 +186,7 @@ The project is structured as follows:
 - `taskpool`: configuration files for the task pool.
 - `model`: configuration files for the models.
 - `dataset`: configuration files for the datasets.
-- `docs/`: documentation for the benchmark. We use [mkdocs](https://www.mkdocs.org/) to generate the documentation. Start the documentation server locally with `mkdocs serve`. The required packages can be installed with `pip install -
+- `docs/`: documentation for the benchmark. We use [mkdocs](https://www.mkdocs.org/) to generate the documentation. Start the documentation server locally with `mkdocs serve`. The required packages can be installed with `pip install -e ".[docs]"`.
 - `examples/`: example scripts for running some of the experiments.
 > **naming convention**: `examples/{method_name}/` contains the files such as bash scripts and jupyter notebooks for the specific method.
 - `tests/`: unit tests for the benchmark.