fusion-bench 0.2.22__py3-none-any.whl → 0.2.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. fusion_bench/__init__.py +4 -0
  2. fusion_bench/compat/method/__init__.py +5 -2
  3. fusion_bench/compat/method/base_algorithm.py +3 -2
  4. fusion_bench/compat/modelpool/base_pool.py +3 -3
  5. fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
  6. fusion_bench/dataset/gpt2_glue.py +1 -1
  7. fusion_bench/method/__init__.py +4 -2
  8. fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
  9. fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
  10. fusion_bench/method/bitdelta/bitdelta.py +7 -23
  11. fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
  12. fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
  13. fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
  14. fusion_bench/method/model_stock/__init__.py +1 -0
  15. fusion_bench/method/model_stock/model_stock.py +309 -0
  16. fusion_bench/method/regmean/clip_regmean.py +3 -6
  17. fusion_bench/method/regmean/regmean.py +27 -56
  18. fusion_bench/method/regmean/utils.py +56 -0
  19. fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
  20. fusion_bench/method/slerp/__init__.py +1 -1
  21. fusion_bench/method/slerp/slerp.py +110 -14
  22. fusion_bench/method/we_moe/flan_t5_we_moe.py +9 -20
  23. fusion_bench/mixins/clip_classification.py +26 -6
  24. fusion_bench/mixins/serialization.py +25 -15
  25. fusion_bench/modelpool/base_pool.py +1 -1
  26. fusion_bench/modelpool/causal_lm/causal_lm.py +262 -43
  27. fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
  28. fusion_bench/models/hf_utils.py +9 -4
  29. fusion_bench/models/linearized/vision_model.py +6 -6
  30. fusion_bench/models/modeling_smile_mistral/__init__.py +1 -0
  31. fusion_bench/models/we_moe.py +8 -8
  32. fusion_bench/taskpool/base_pool.py +99 -17
  33. fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
  34. fusion_bench/taskpool/dummy.py +101 -13
  35. fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
  36. fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
  37. fusion_bench/utils/__init__.py +1 -0
  38. fusion_bench/utils/data.py +6 -4
  39. fusion_bench/utils/devices.py +7 -4
  40. fusion_bench/utils/dtype.py +3 -2
  41. fusion_bench/utils/lazy_state_dict.py +82 -19
  42. fusion_bench/utils/packages.py +3 -3
  43. fusion_bench/utils/parameters.py +0 -2
  44. fusion_bench/utils/timer.py +92 -10
  45. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -1
  46. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +53 -47
  47. fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
  48. fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
  49. fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
  50. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
  51. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
  52. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
  53. {fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/utils/lazy_state_dict.py
@@ -2,7 +2,18 @@ import json
 import logging
 import os
 from copy import deepcopy
-from typing import TYPE_CHECKING, Dict, Iterator, List, Mapping, Optional, Tuple, Type
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Generic,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)

 import torch
 from accelerate import init_empty_weights
@@ -11,10 +22,12 @@ from huggingface_hub import snapshot_download
 from safetensors import safe_open
 from safetensors.torch import load_file
 from torch import nn
+from torch.nn.modules.module import _IncompatibleKeys
 from transformers import AutoConfig

 from fusion_bench.utils.dtype import parse_dtype
 from fusion_bench.utils.packages import import_object
+from fusion_bench.utils.type import TorchModelType

 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -49,7 +62,7 @@ def resolve_checkpoint_path(
     )


-class LazyStateDict(Mapping[str, torch.Tensor]):
+class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
     """
     Dictionary-like object that lazily loads a state dict from a checkpoint path.
     """
@@ -66,8 +79,8 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
     def __init__(
         self,
         checkpoint: str,
-        meta_module_class: Optional[Type[nn.Module]] = None,
-        meta_module: Optional[nn.Module] = None,
+        meta_module_class: Optional[Type[TorchModelType]] = None,
+        meta_module: Optional[TorchModelType] = None,
         cache_state_dict: bool = False,
         torch_dtype: Optional[torch.dtype] = None,
         device: str = "cpu",
@@ -88,15 +101,19 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
             hf_proxies (Dict, optional): Proxies to use for downloading from Hugging Face Hub.
         """
         self.cache_state_dict = cache_state_dict
+
+        # Validate that both meta_module_class and meta_module are not provided
+        if meta_module_class is not None and meta_module is not None:
+            raise ValueError(
+                "Cannot provide both meta_module_class and meta_module, please provide only one."
+            )
+
         self.meta_module_class = meta_module_class
         if isinstance(self.meta_module_class, str):
            self.meta_module_class = import_object(self.meta_module_class)
         self.meta_module = meta_module
+
         if self.meta_module_class is not None:
-            if self.meta_module is not None:
-                raise ValueError(
-                    "Cannot provide both meta_module_class and meta_module, please provide only one."
-                )
             with init_empty_weights():
                 self.meta_module = self.meta_module_class.from_pretrained(
                     checkpoint,
@@ -173,9 +190,13 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
         """
        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
         """
+        if hasattr(self, "_cached_dtype"):
+            return self._cached_dtype
+
         first_key = next(iter(self.keys()))
         first_param = self[first_key]
-        return first_param.dtype
+        self._cached_dtype = first_param.dtype
+        return self._cached_dtype

     def state_dict(self, keep_vars: bool = False) -> "LazyStateDict":
         """
@@ -321,9 +342,7 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
         if self._state_dict_cache is not None:
             self._state_dict_cache[key] = value
         else:
-            log.warning(
-                "State dict cache is disabled, setting a tensor will not update the cache."
-            )
+            log.warning("State dict cache is disabled, initializing the cache.")
             self._state_dict_cache = {key: value}

     def __contains__(self, key: str) -> bool:
@@ -339,7 +358,7 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
                    self._checkpoint_files[0], key, update_cache=False
                )
                return tensor is not None
-            except Exception:
+            except (KeyError, FileNotFoundError, RuntimeError, EOFError):
                return False
        return False

@@ -409,8 +428,8 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
        )

    def load_state_dict(
-        self, state_dict: Dict[str, torch.Tensor], strict: bool = True
-    ) -> None:
+        self, state_dict: Mapping[str, torch.Tensor], strict: bool = True
+    ) -> _IncompatibleKeys:
        """
        Load a state dict into this LazyStateDict.
        This method is only for compatibility with nn.Module and it overrides the cache of LazyStateDict.
@@ -419,16 +438,60 @@ class LazyStateDict(Mapping[str, torch.Tensor]):
            state_dict (Dict[str, torch.Tensor]): The state dict to load.
            strict (bool): Whether to enforce that all keys in the state dict are present in this LazyStateDict.
        """
+        if not isinstance(state_dict, Mapping):
+            raise TypeError(
+                f"Expected state_dict to be dict-like, got {type(state_dict)}."
+            )
+
+        missing_keys: list[str] = []
+        unexpected_keys: list[str] = []
+        error_msgs: list[str] = []
+
        log.warning(
            "Loading state dict into LazyStateDict is not recommended, as it may lead to unexpected behavior. "
            "Use with caution."
        )
+
+        # Check for unexpected keys in the provided state_dict
+        for key in state_dict:
+            if key not in self:
+                unexpected_keys.append(key)
+
+        # Check for missing keys that are expected in this LazyStateDict
+        for key in self.keys():
+            if key not in state_dict:
+                missing_keys.append(key)
+
+        # Handle strict mode
        if strict:
-            for key in state_dict:
-                if key not in self:
-                    raise KeyError(f"Key {key} not found in LazyStateDict.")
+            if len(unexpected_keys) > 0:
+                error_msgs.insert(
+                    0,
+                    "Unexpected key(s) in state_dict: {}. ".format(
+                        ", ".join(f'"{k}"' for k in unexpected_keys)
+                    ),
+                )
+            if len(missing_keys) > 0:
+                error_msgs.insert(
+                    0,
+                    "Missing key(s) in state_dict: {}. ".format(
+                        ", ".join(f'"{k}"' for k in missing_keys)
+                    ),
+                )
+
+            if len(error_msgs) > 0:
+                raise RuntimeError(
+                    "Error(s) in loading state_dict for {}:\n\t{}".format(
+                        self.__class__.__name__, "\n\t".join(error_msgs)
+                    )
+                )
+
+        # Load the state dict values
        for key, value in state_dict.items():
-            self[key] = value
+            if key in self:  # Only set keys that exist in this LazyStateDict
+                self[key] = value
+
+        return _IncompatibleKeys(missing_keys, unexpected_keys)

    def __getattr__(self, name: str):
        if "meta_module" in self.__dict__:
fusion_bench/utils/packages.py
@@ -1,7 +1,7 @@
 import importlib.metadata
 import importlib.util
 from functools import lru_cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

 from packaging import version

@@ -69,7 +69,7 @@ def is_vllm_available():
     return _is_package_available("vllm")


-def import_object(abs_obj_name: str):
+def import_object(abs_obj_name: str) -> Any:
     """
     Imports a class from a module given the absolute class name.

@@ -84,7 +84,7 @@ def import_object(abs_obj_name: str):
     return getattr(module, obj_name)


-def compare_versions(v1, v2):
+def compare_versions(v1: str, v2: str) -> int:
    """Compare two version strings.
    Returns -1 if v1 < v2, 0 if v1 == v2, 1 if v1 > v2"""

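
The annotations added to `import_object` and `compare_versions` only make the documented contract explicit. A quick illustration of that contract (assumed behavior based on the docstrings above, not code shipped in this diff):

```python
from fusion_bench.utils.packages import compare_versions, import_object

# import_object resolves an absolute dotted name to the object it points at.
Path = import_object("pathlib.Path")  # roughly `from pathlib import Path`

# compare_versions acts like a classic cmp(): returns -1, 0, or 1.
assert compare_versions("0.2.22", "0.2.23") == -1
assert compare_versions("1.0", "1.0") == 0
```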
fusion_bench/utils/parameters.py
@@ -129,7 +129,6 @@ def human_readable(num: int) -> str:
    Converts a number into a human-readable string with appropriate magnitude suffix.

    Examples:
-
    ```python
    print(human_readable(1500))
    # Output: '1.50K'
@@ -201,7 +200,6 @@ def count_parameters(module: nn.Module, non_zero_only: bool = False) -> tuple[in
        tuple: A tuple containing the number of trainable parameters and the total number of parameters.

    Examples:
-
    ```python
    # Count the parameters
    trainable_params, all_params = count_parameters(model)
fusion_bench/utils/timer.py
@@ -6,38 +6,120 @@ log = logging.getLogger(__name__)

 class timeit_context:
     """
-    Usage:
+    A context manager for measuring and logging execution time of code blocks.

-    ```python
-    with timeit_context() as timer:
-        ... # code block to be measured
-    ```
+    This context manager provides precise timing measurements with automatic logging
+    of elapsed time. It supports nested timing contexts with proper indentation
+    for hierarchical timing analysis, making it ideal for profiling complex
+    operations with multiple sub-components.
+
+    Args:
+        msg (str, optional): Custom message to identify the timed code block.
+            If provided, logs "[BEGIN] {msg}" at start and includes context
+            in the final timing report. Defaults to None.
+        loglevel (int, optional): Python logging level for output messages.
+            Uses standard logging levels (DEBUG=10, INFO=20, WARNING=30, etc.).
+            Defaults to logging.INFO.
+
+    Example:
+        Basic usage:
+        ```python
+        with timeit_context("data loading"):
+            data = load_large_dataset()
+        # Logs: [BEGIN] data loading
+        # Logs: [END] Elapsed time: 2.34s
+        ```
+
+        Nested timing:
+        ```python
+        with timeit_context("model training"):
+            with timeit_context("data preprocessing"):
+                preprocess_data()
+            with timeit_context("forward pass"):
+                model(data)
+        # Output shows nested structure:
+        # [BEGIN] model training
+        #  [BEGIN] data preprocessing
+        #  [END] Elapsed time: 0.15s
+        #  [BEGIN] forward pass
+        #  [END] Elapsed time: 0.89s
+        # [END] Elapsed time: 1.04s
+        ```
+
+        Custom log level:
+        ```python
+        with timeit_context("debug operation", loglevel=logging.DEBUG):
+            debug_function()
+        ```
     """

     nest_level = -1

     def _log(self, msg):
+        """
+        Internal method for logging messages with appropriate stack level.
+
+        This helper method ensures that log messages appear to originate from
+        the caller's code rather than from internal timer methods, providing
+        more useful debugging information.
+
+        Args:
+            msg (str): The message to log at the configured log level.
+        """
         log.log(self.loglevel, msg, stacklevel=3)

     def __init__(self, msg: str = None, loglevel=logging.INFO) -> None:
+        """
+        Initialize a new timing context with optional message and log level.
+
+        Args:
+            msg (str, optional): Descriptive message for the timed operation.
+                If provided, will be included in the begin/end log messages
+                to help identify what is being timed. Defaults to None.
+            loglevel (int, optional): Python logging level for timer output.
+                Common values include:
+                - logging.DEBUG (10): Detailed debugging information
+                - logging.INFO (20): General information (default)
+                - logging.WARNING (30): Warning messages
+                - logging.ERROR (40): Error messages
+                Defaults to logging.INFO.
+        """
         self.loglevel = loglevel
         self.msg = msg

     def __enter__(self) -> None:
         """
-        Sets the start time and logs an optional message indicating the start of the code block execution.
+        Enter the timing context and start the timer.

-        Args:
-            msg: str, optional message to log
+        This method is automatically called when entering the 'with' statement.
+        It records the current timestamp, increments the nesting level for
+        proper log indentation, and optionally logs a begin message.
+
+        Returns:
+            None: This context manager doesn't return a value to the 'as' clause.
+                All timing information is handled internally and logged automatically.
         """
         self.start_time = time.time()
         timeit_context.nest_level += 1
         if self.msg is not None:
             self._log(" " * timeit_context.nest_level + "[BEGIN] " + str(self.msg))

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         """
-        Calculates the elapsed time and logs it, along with an optional message indicating the end of the code block execution.
+        Exit the timing context and log the elapsed time.
+
+        This method is automatically called when exiting the 'with' statement,
+        whether through normal completion or exception. It calculates the total
+        elapsed time and logs the results with proper nesting indentation.
+
+        Args:
+            exc_type (type): Exception type if an exception occurred, None otherwise.
+            exc_val (Exception): Exception instance if an exception occurred, None otherwise.
+            exc_tb (traceback): Exception traceback if an exception occurred, None otherwise.
+
+        Returns:
+            None: Does not suppress exceptions (returns None/False implicitly).
+                Any exceptions that occurred in the timed block will propagate normally.
         """
         end_time = time.time()
         elapsed_time = end_time - self.start_time
{fusion_bench-0.2.22.dist-info → fusion_bench-0.2.23.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fusion_bench
-Version: 0.2.22
+Version: 0.2.23
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 Project-URL: Repository, https://github.com/tanganke/fusion_bench