fusion-bench 0.2.22__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/dataset/gpt2_glue.py +1 -1
  7. fusion_bench/method/__init__.py +12 -2
  8. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  9. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  10. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  11. fusion_bench/method/ensemble.py +17 -2
  12. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  13. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  14. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  15. fusion_bench/method/linear/__init__.py +6 -2
  16. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  17. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  18. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  19. fusion_bench/method/model_stock/__init__.py +1 -0
  20. fusion_bench/method/model_stock/model_stock.py +309 -0
  21. fusion_bench/method/regmean/clip_regmean.py +3 -6
  22. fusion_bench/method/regmean/regmean.py +27 -56
  23. fusion_bench/method/regmean/utils.py +56 -0
  24. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  25. fusion_bench/method/simple_average.py +2 -2
  26. fusion_bench/method/slerp/__init__.py +1 -1
  27. fusion_bench/method/slerp/slerp.py +110 -14
  28. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  29. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  30. fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
  31. fusion_bench/method/wudi/__init__.py +1 -0
  32. fusion_bench/method/wudi/wudi.py +105 -0
  33. fusion_bench/mixins/clip_classification.py +26 -6
  34. fusion_bench/mixins/lightning_fabric.py +4 -0
  35. fusion_bench/mixins/serialization.py +40 -83
  36. fusion_bench/modelpool/base_pool.py +1 -1
  37. fusion_bench/modelpool/causal_lm/causal_lm.py +285 -44
  38. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  39. fusion_bench/models/hf_clip.py +4 -0
  40. fusion_bench/models/hf_utils.py +10 -4
  41. fusion_bench/models/linearized/vision_model.py +6 -6
  42. fusion_bench/models/model_card_templates/default.md +8 -1
  43. fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
  44. fusion_bench/models/we_moe.py +8 -8
  45. fusion_bench/models/wrappers/ensemble.py +136 -7
  46. fusion_bench/scripts/cli.py +2 -2
  47. fusion_bench/taskpool/base_pool.py +99 -17
  48. fusion_bench/taskpool/clip_vision/taskpool.py +12 -5
  49. fusion_bench/taskpool/dummy.py +101 -13
  50. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  51. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  52. fusion_bench/utils/__init__.py +1 -0
  53. fusion_bench/utils/data.py +6 -4
  54. fusion_bench/utils/devices.py +36 -11
  55. fusion_bench/utils/dtype.py +3 -2
  56. fusion_bench/utils/lazy_state_dict.py +85 -19
  57. fusion_bench/utils/packages.py +3 -3
  58. fusion_bench/utils/parameters.py +0 -2
  59. fusion_bench/utils/rich_utils.py +7 -3
  60. fusion_bench/utils/timer.py +92 -10
  61. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/METADATA +10 -3
  62. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/RECORD +77 -64
  63. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  64. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  65. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  66. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  67. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  68. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  69. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  70. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  71. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  72. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  73. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  74. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  75. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/WHEEL +0 -0
  76. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/entry_points.txt +0 -0
  77. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/licenses/LICENSE +0 -0
  78. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.24.dist-info}/top_level.txt +0 -0
@@ -129,7 +129,6 @@ def human_readable(num: int) -> str:
129
129
  Converts a number into a human-readable string with appropriate magnitude suffix.
130
130
 
131
131
  Examples:
132
-
133
132
  ```python
134
133
  print(human_readable(1500))
135
134
  # Output: '1.50K'
@@ -201,7 +200,6 @@ def count_parameters(module: nn.Module, non_zero_only: bool = False) -> tuple[in
201
200
  tuple: A tuple containing the number of trainable parameters and the total number of parameters.
202
201
 
203
202
  Examples:
204
-
205
203
  ```python
206
204
  # Count the parameters
207
205
  trainable_params, all_params = count_parameters(model)
@@ -188,17 +188,21 @@ if __name__ == "__main__":
188
188
  display_available_styles()
189
189
 
190
190
 
191
- def setup_colorlogging(force=False, **config_kwargs):
191
+ def setup_colorlogging(
192
+ force=False,
193
+ level=logging.INFO,
194
+ **kwargs,
195
+ ):
192
196
  """
193
197
  Sets up color logging for the application.
194
198
  """
195
199
  FORMAT = "%(message)s"
196
200
 
197
201
  logging.basicConfig(
198
- level=logging.INFO,
202
+ level=level,
199
203
  format=FORMAT,
200
204
  datefmt="[%X]",
201
205
  handlers=[RichHandler()],
202
206
  force=force,
203
- **config_kwargs,
207
+ **kwargs,
204
208
  )
@@ -6,38 +6,120 @@ log = logging.getLogger(__name__)
6
6
 
7
7
  class timeit_context:
8
8
  """
9
- Usage:
9
+ A context manager for measuring and logging execution time of code blocks.
10
10
 
11
- ```python
12
- with timeit_context() as timer:
13
- ... # code block to be measured
14
- ```
11
+ This context manager provides precise timing measurements with automatic logging
12
+ of elapsed time. It supports nested timing contexts with proper indentation
13
+ for hierarchical timing analysis, making it ideal for profiling complex
14
+ operations with multiple sub-components.
15
+
16
+ Args:
17
+ msg (str, optional): Custom message to identify the timed code block.
18
+ If provided, logs "[BEGIN] {msg}" at start and includes context
19
+ in the final timing report. Defaults to None.
20
+ loglevel (int, optional): Python logging level for output messages.
21
+ Uses standard logging levels (DEBUG=10, INFO=20, WARNING=30, etc.).
22
+ Defaults to logging.INFO.
23
+
24
+ Example:
25
+ Basic usage:
26
+ ```python
27
+ with timeit_context("data loading"):
28
+ data = load_large_dataset()
29
+ # Logs: [BEGIN] data loading
30
+ # Logs: [END] Elapsed time: 2.34s
31
+ ```
32
+
33
+ Nested timing:
34
+ ```python
35
+ with timeit_context("model training"):
36
+ with timeit_context("data preprocessing"):
37
+ preprocess_data()
38
+ with timeit_context("forward pass"):
39
+ model(data)
40
+ # Output shows nested structure:
41
+ # [BEGIN] model training
42
+ # [BEGIN] data preprocessing
43
+ # [END] Elapsed time: 0.15s
44
+ # [BEGIN] forward pass
45
+ # [END] Elapsed time: 0.89s
46
+ # [END] Elapsed time: 1.04s
47
+ ```
48
+
49
+ Custom log level:
50
+ ```python
51
+ with timeit_context("debug operation", loglevel=logging.DEBUG):
52
+ debug_function()
53
+ ```
15
54
  """
16
55
 
17
56
  nest_level = -1
18
57
 
19
58
  def _log(self, msg):
59
+ """
60
+ Internal method for logging messages with appropriate stack level.
61
+
62
+ This helper method ensures that log messages appear to originate from
63
+ the caller's code rather than from internal timer methods, providing
64
+ more useful debugging information.
65
+
66
+ Args:
67
+ msg (str): The message to log at the configured log level.
68
+ """
20
69
  log.log(self.loglevel, msg, stacklevel=3)
21
70
 
22
71
  def __init__(self, msg: str = None, loglevel=logging.INFO) -> None:
72
+ """
73
+ Initialize a new timing context with optional message and log level.
74
+
75
+ Args:
76
+ msg (str, optional): Descriptive message for the timed operation.
77
+ If provided, will be included in the begin/end log messages
78
+ to help identify what is being timed. Defaults to None.
79
+ loglevel (int, optional): Python logging level for timer output.
80
+ Common values include:
81
+ - logging.DEBUG (10): Detailed debugging information
82
+ - logging.INFO (20): General information (default)
83
+ - logging.WARNING (30): Warning messages
84
+ - logging.ERROR (40): Error messages
85
+ Defaults to logging.INFO.
86
+ """
23
87
  self.loglevel = loglevel
24
88
  self.msg = msg
25
89
 
26
90
  def __enter__(self) -> None:
27
91
  """
28
- Sets the start time and logs an optional message indicating the start of the code block execution.
92
+ Enter the timing context and start the timer.
29
93
 
30
- Args:
31
- msg: str, optional message to log
94
+ This method is automatically called when entering the 'with' statement.
95
+ It records the current timestamp, increments the nesting level for
96
+ proper log indentation, and optionally logs a begin message.
97
+
98
+ Returns:
99
+ None: This context manager doesn't return a value to the 'as' clause.
100
+ All timing information is handled internally and logged automatically.
32
101
  """
33
102
  self.start_time = time.time()
34
103
  timeit_context.nest_level += 1
35
104
  if self.msg is not None:
36
105
  self._log(" " * timeit_context.nest_level + "[BEGIN] " + str(self.msg))
37
106
 
38
- def __exit__(self, exc_type, exc_val, exc_tb):
107
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
39
108
  """
40
- Calculates the elapsed time and logs it, along with an optional message indicating the end of the code block execution.
109
+ Exit the timing context and log the elapsed time.
110
+
111
+ This method is automatically called when exiting the 'with' statement,
112
+ whether through normal completion or exception. It calculates the total
113
+ elapsed time and logs the results with proper nesting indentation.
114
+
115
+ Args:
116
+ exc_type (type): Exception type if an exception occurred, None otherwise.
117
+ exc_val (Exception): Exception instance if an exception occurred, None otherwise.
118
+ exc_tb (traceback): Exception traceback if an exception occurred, None otherwise.
119
+
120
+ Returns:
121
+ None: Does not suppress exceptions (returns None/False implicitly).
122
+ Any exceptions that occurred in the timed block will propagate normally.
41
123
  """
42
124
  end_time = time.time()
43
125
  elapsed_time = end_time - self.start_time
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fusion_bench
3
- Version: 0.2.22
3
+ Version: 0.2.24
4
4
  Summary: A Comprehensive Benchmark of Deep Model Fusion
5
5
  Author-email: Anke Tang <tang.anke@foxmail.com>
6
6
  Project-URL: Repository, https://github.com/tanganke/fusion_bench
@@ -23,12 +23,19 @@ Requires-Dist: rich
23
23
  Requires-Dist: scipy
24
24
  Requires-Dist: h5py
25
25
  Requires-Dist: pytest
26
+ Requires-Dist: joblib
27
+ Requires-Dist: bidict
26
28
  Requires-Dist: transformers!=4.49
27
29
  Requires-Dist: pillow!=11.2.1
28
30
  Provides-Extra: lm-eval-harness
29
31
  Requires-Dist: lm-eval; extra == "lm-eval-harness"
30
32
  Requires-Dist: immutabledict; extra == "lm-eval-harness"
31
33
  Requires-Dist: langdetect; extra == "lm-eval-harness"
34
+ Requires-Dist: rich-run; extra == "lm-eval-harness"
35
+ Provides-Extra: docs
36
+ Requires-Dist: mkdocs; extra == "docs"
37
+ Requires-Dist: mkdocs-material; extra == "docs"
38
+ Requires-Dist: mkdocstrings[python]; extra == "docs"
32
39
  Dynamic: license-file
33
40
 
34
41
  <div align='center'>
@@ -151,7 +158,7 @@ This will install the latest version of fusion-bench and the dependencies requir
151
158
  Documentation for using LM-Eval Harness within FusionBench framework can be found at [this online documentation](https://tanganke.github.io/fusion_bench/taskpool/lm_eval_harness) or in the [`docs/taskpool/lm_eval_harness.md`](docs/taskpool/lm_eval_harness.md) markdown file.
152
159
 
153
160
  > [!TIP]
154
- > Documentation for merging large language models using FusionBench can be found at [this online documentation](https://tanganke.github.io/fusion_bench/modelpool/causal_lm) or in the [`docs/modelpool/causal_lm.md`](docs/modelpool/causal_lm.md) markdown file.
161
+ > Documentation for merging large language models using FusionBench can be found at [this online documentation](https://tanganke.github.io/fusion_bench/modelpool/llm) or in the [`docs/modelpool/llm/index.md`](docs/modelpool/llm/index.md) markdown file.
155
162
 
156
163
  ## Introduction to Deep Model Fusion
157
164
 
@@ -179,7 +186,7 @@ The project is structured as follows:
179
186
  - `taskpool`: configuration files for the task pool.
180
187
  - `model`: configuration files for the models.
181
188
  - `dataset`: configuration files for the datasets.
182
- - `docs/`: documentation for the benchmark. We use [mkdocs](https://www.mkdocs.org/) to generate the documentation. Start the documentation server locally with `mkdocs serve`. The required packages can be installed with `pip install -r mkdocs-requirements.txt`.
189
+ - `docs/`: documentation for the benchmark. We use [mkdocs](https://www.mkdocs.org/) to generate the documentation. Start the documentation server locally with `mkdocs serve`. The required packages can be installed with `pip install -e ".[docs]"`.
183
190
  - `examples/`: example scripts for running some of the experiments.
184
191
  > **naming convention**: `examples/{method_name}/` contains the files such as bash scripts and jupyter notebooks for the specific method.
185
192
  - `tests/`: unit tests for the benchmark.