fusion-bench 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +25 -2
- fusion_bench/compat/method/__init__.py +5 -2
- fusion_bench/compat/method/base_algorithm.py +3 -2
- fusion_bench/compat/modelpool/base_pool.py +3 -3
- fusion_bench/compat/taskpool/clip_image_classification.py +1 -1
- fusion_bench/constants/__init__.py +1 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/gpt2_glue.py +1 -1
- fusion_bench/method/__init__.py +12 -4
- fusion_bench/method/analysis/task_vector_cos_similarity.py +95 -12
- fusion_bench/method/analysis/task_vector_violin_plot.py +160 -52
- fusion_bench/method/bitdelta/__init__.py +1 -0
- fusion_bench/method/bitdelta/bitdelta.py +7 -23
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/expert_sparsity/mixtral/dynamic_skipping.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/layer_wise_pruning.py +2 -0
- fusion_bench/method/expert_sparsity/mixtral/progressive_pruning.py +2 -0
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +0 -4
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +2 -2
- fusion_bench/method/linear/simple_average_for_llama.py +16 -11
- fusion_bench/method/model_stock/__init__.py +1 -0
- fusion_bench/method/model_stock/model_stock.py +309 -0
- fusion_bench/method/regmean/clip_regmean.py +3 -6
- fusion_bench/method/regmean/regmean.py +27 -56
- fusion_bench/method/regmean/utils.py +56 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +21 -60
- fusion_bench/method/simple_average.py +7 -7
- fusion_bench/method/slerp/__init__.py +1 -1
- fusion_bench/method/slerp/slerp.py +110 -14
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +1 -2
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +40 -31
- fusion_bench/method/smile_upscaling/smile_upscaling.py +1 -1
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +320 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/weighted_average/llama.py +1 -1
- fusion_bench/mixins/clip_classification.py +37 -48
- fusion_bench/mixins/serialization.py +30 -10
- fusion_bench/modelpool/base_pool.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +293 -75
- fusion_bench/modelpool/seq2seq_lm/modelpool.py +146 -0
- fusion_bench/models/__init__.py +5 -0
- fusion_bench/models/hf_utils.py +69 -86
- fusion_bench/models/linearized/vision_model.py +6 -6
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +1 -8
- fusion_bench/models/modeling_smile_mistral/__init__.py +2 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +1 -5
- fusion_bench/models/we_moe.py +8 -8
- fusion_bench/programs/fabric_fusion_program.py +29 -60
- fusion_bench/scripts/cli.py +34 -1
- fusion_bench/taskpool/base_pool.py +99 -17
- fusion_bench/taskpool/clip_vision/taskpool.py +10 -5
- fusion_bench/taskpool/dummy.py +101 -13
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +80 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +28 -0
- fusion_bench/utils/__init__.py +2 -0
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/data.py +6 -4
- fusion_bench/utils/devices.py +7 -4
- fusion_bench/utils/dtype.py +3 -2
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +117 -19
- fusion_bench/utils/modelscope.py +3 -3
- fusion_bench/utils/packages.py +3 -3
- fusion_bench/utils/parameters.py +0 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/timer.py +92 -10
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA +1 -23
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/RECORD +89 -75
- fusion_bench_config/_get_started/llm_slerp.yaml +12 -0
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/model_stock/model_stock.yaml +12 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +4 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +1 -1
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/top_level.txt +0 -0
fusion_bench/utils/devices.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import gc
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
|
-
from typing import List, Optional, Union
|
|
4
|
+
from typing import Any, List, Optional, Union
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
from transformers.utils import (
|
|
@@ -12,6 +12,8 @@ from transformers.utils import (
|
|
|
12
12
|
is_torch_xpu_available,
|
|
13
13
|
)
|
|
14
14
|
|
|
15
|
+
from .type import T
|
|
16
|
+
|
|
15
17
|
__all__ = [
|
|
16
18
|
"clear_cuda_cache",
|
|
17
19
|
"to_device",
|
|
@@ -37,7 +39,7 @@ def clear_cuda_cache():
|
|
|
37
39
|
log.warning("CUDA is not available. No cache to clear.")
|
|
38
40
|
|
|
39
41
|
|
|
40
|
-
def to_device(obj, device: Optional[torch.device], **kwargs):
|
|
42
|
+
def to_device(obj: T, device: Optional[torch.device], **kwargs: Any) -> T:
|
|
41
43
|
"""
|
|
42
44
|
Move a given object to the specified device.
|
|
43
45
|
|
|
@@ -102,7 +104,7 @@ def num_devices(devices: Union[int, List[int], str]) -> int:
|
|
|
102
104
|
)
|
|
103
105
|
|
|
104
106
|
|
|
105
|
-
def get_device(obj) -> torch.device:
|
|
107
|
+
def get_device(obj: Any) -> torch.device:
|
|
106
108
|
"""
|
|
107
109
|
Get the device of a given object.
|
|
108
110
|
|
|
@@ -151,6 +153,7 @@ def get_current_device() -> torch.device:
|
|
|
151
153
|
If not set, it defaults to "0".
|
|
152
154
|
|
|
153
155
|
Example:
|
|
156
|
+
|
|
154
157
|
>>> device = get_current_device()
|
|
155
158
|
>>> print(device)
|
|
156
159
|
xpu:0 # or npu:0, mps:0, cuda:0, cpu depending on availability
|
|
@@ -241,7 +244,7 @@ def cleanup_cuda():
|
|
|
241
244
|
torch.cuda.reset_peak_memory_stats()
|
|
242
245
|
|
|
243
246
|
|
|
244
|
-
def print_memory_usage(print_fn=print):
|
|
247
|
+
def print_memory_usage(print_fn=print) -> str:
|
|
245
248
|
"""
|
|
246
249
|
Print the current GPU memory usage.
|
|
247
250
|
|
fusion_bench/utils/dtype.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import contextlib
|
|
2
|
-
from typing import Dict, Generator, Iterable, Optional, Tuple
|
|
2
|
+
from typing import Dict, Generator, Iterable, Optional, Tuple, Union
|
|
3
3
|
|
|
4
4
|
import torch
|
|
5
5
|
from transformers.utils import (
|
|
@@ -25,7 +25,7 @@ PRECISION_STR_TO_DTYPE: Dict[str, torch.dtype] = {
|
|
|
25
25
|
}
|
|
26
26
|
|
|
27
27
|
|
|
28
|
-
def parse_dtype(dtype: Optional[str]):
|
|
28
|
+
def parse_dtype(dtype: Optional[str]) -> Optional[torch.dtype]:
|
|
29
29
|
"""
|
|
30
30
|
Parses a string representation of a data type and returns the corresponding torch.dtype.
|
|
31
31
|
|
|
@@ -92,6 +92,7 @@ def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
|
|
|
92
92
|
ContextManager: context manager for setting default dtype.
|
|
93
93
|
|
|
94
94
|
Example:
|
|
95
|
+
|
|
95
96
|
>>> with set_default_dtype(torch.bfloat16):
|
|
96
97
|
>>> x = torch.tensor([1, 2, 3])
|
|
97
98
|
>>> x.dtype
|
fusion_bench/utils/fabric.py
CHANGED
|
@@ -3,9 +3,9 @@ from typing import Optional
|
|
|
3
3
|
|
|
4
4
|
import lightning as L
|
|
5
5
|
|
|
6
|
-
from fusion_bench.utils.pylogger import
|
|
6
|
+
from fusion_bench.utils.pylogger import get_rankzero_logger
|
|
7
7
|
|
|
8
|
-
log =
|
|
8
|
+
log = get_rankzero_logger(__name__)
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
def seed_everything_by_time(fabric: Optional[L.Fabric] = None):
|
|
fusion_bench/utils/lazy_imports.py
CHANGED
|
@@ -72,3 +72,26 @@ class LazyImporter(ModuleType):
|
|
|
72
72
|
|
|
73
73
|
def __reduce__(self):
|
|
74
74
|
return (self.__class__, (self._name, self.__file__, self._import_structure))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class LazyModule(ModuleType):
|
|
78
|
+
"""Module wrapper for lazy import.
|
|
79
|
+
Adapted from Optuna: https://github.com/optuna/optuna/blob/1f92d496b0c4656645384e31539e4ee74992ff55/optuna/__init__.py
|
|
80
|
+
|
|
81
|
+
This class wraps specified module and lazily import it when they are actually accessed.
|
|
82
|
+
|
|
83
|
+
Args:
|
|
84
|
+
name: Name of module to apply lazy import.
|
|
85
|
+
"""
|
|
86
|
+
|
|
87
|
+
def __init__(self, name: str) -> None:
|
|
88
|
+
super().__init__(name)
|
|
89
|
+
self._name = name
|
|
90
|
+
|
|
91
|
+
def _load(self) -> ModuleType:
|
|
92
|
+
module = importlib.import_module(self._name)
|
|
93
|
+
self.__dict__.update(module.__dict__)
|
|
94
|
+
return module
|
|
95
|
+
|
|
96
|
+
def __getattr__(self, item: str) -> Any:
|
|
97
|
+
return getattr(self._load(), item)
|
|
fusion_bench/utils/lazy_state_dict.py
CHANGED
|
@@ -2,7 +2,18 @@ import json
|
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
4
|
from copy import deepcopy
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import (
|
|
6
|
+
TYPE_CHECKING,
|
|
7
|
+
Dict,
|
|
8
|
+
Generic,
|
|
9
|
+
Iterator,
|
|
10
|
+
List,
|
|
11
|
+
Mapping,
|
|
12
|
+
Optional,
|
|
13
|
+
Tuple,
|
|
14
|
+
Type,
|
|
15
|
+
Union,
|
|
16
|
+
)
|
|
6
17
|
|
|
7
18
|
import torch
|
|
8
19
|
from accelerate import init_empty_weights
|
|
@@ -11,10 +22,12 @@ from huggingface_hub import snapshot_download
|
|
|
11
22
|
from safetensors import safe_open
|
|
12
23
|
from safetensors.torch import load_file
|
|
13
24
|
from torch import nn
|
|
25
|
+
from torch.nn.modules.module import _IncompatibleKeys
|
|
14
26
|
from transformers import AutoConfig
|
|
15
27
|
|
|
16
28
|
from fusion_bench.utils.dtype import parse_dtype
|
|
17
29
|
from fusion_bench.utils.packages import import_object
|
|
30
|
+
from fusion_bench.utils.type import TorchModelType
|
|
18
31
|
|
|
19
32
|
if TYPE_CHECKING:
|
|
20
33
|
from transformers import PretrainedConfig
|
|
@@ -49,7 +62,7 @@ def resolve_checkpoint_path(
|
|
|
49
62
|
)
|
|
50
63
|
|
|
51
64
|
|
|
52
|
-
class LazyStateDict:
|
|
65
|
+
class LazyStateDict(Mapping[str, torch.Tensor], Generic[TorchModelType]):
|
|
53
66
|
"""
|
|
54
67
|
Dictionary-like object that lazily loads a state dict from a checkpoint path.
|
|
55
68
|
"""
|
|
@@ -66,8 +79,8 @@ class LazyStateDict:
|
|
|
66
79
|
def __init__(
|
|
67
80
|
self,
|
|
68
81
|
checkpoint: str,
|
|
69
|
-
meta_module_class: Optional[Type[
|
|
70
|
-
meta_module: Optional[
|
|
82
|
+
meta_module_class: Optional[Type[TorchModelType]] = None,
|
|
83
|
+
meta_module: Optional[TorchModelType] = None,
|
|
71
84
|
cache_state_dict: bool = False,
|
|
72
85
|
torch_dtype: Optional[torch.dtype] = None,
|
|
73
86
|
device: str = "cpu",
|
|
@@ -88,15 +101,19 @@ class LazyStateDict:
|
|
|
88
101
|
hf_proxies (Dict, optional): Proxies to use for downloading from Hugging Face Hub.
|
|
89
102
|
"""
|
|
90
103
|
self.cache_state_dict = cache_state_dict
|
|
104
|
+
|
|
105
|
+
# Validate that both meta_module_class and meta_module are not provided
|
|
106
|
+
if meta_module_class is not None and meta_module is not None:
|
|
107
|
+
raise ValueError(
|
|
108
|
+
"Cannot provide both meta_module_class and meta_module, please provide only one."
|
|
109
|
+
)
|
|
110
|
+
|
|
91
111
|
self.meta_module_class = meta_module_class
|
|
92
112
|
if isinstance(self.meta_module_class, str):
|
|
93
113
|
self.meta_module_class = import_object(self.meta_module_class)
|
|
94
114
|
self.meta_module = meta_module
|
|
115
|
+
|
|
95
116
|
if self.meta_module_class is not None:
|
|
96
|
-
if self.meta_module is not None:
|
|
97
|
-
raise ValueError(
|
|
98
|
-
"Cannot provide both meta_module_class and meta_module, please provide only one."
|
|
99
|
-
)
|
|
100
117
|
with init_empty_weights():
|
|
101
118
|
self.meta_module = self.meta_module_class.from_pretrained(
|
|
102
119
|
checkpoint,
|
|
@@ -168,12 +185,25 @@ class LazyStateDict:
|
|
|
168
185
|
def config(self) -> "PretrainedConfig":
|
|
169
186
|
return AutoConfig.from_pretrained(self._checkpoint)
|
|
170
187
|
|
|
188
|
+
@property
|
|
189
|
+
def dtype(self) -> torch.dtype:
|
|
190
|
+
"""
|
|
191
|
+
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
|
192
|
+
"""
|
|
193
|
+
if hasattr(self, "_cached_dtype"):
|
|
194
|
+
return self._cached_dtype
|
|
195
|
+
|
|
196
|
+
first_key = next(iter(self.keys()))
|
|
197
|
+
first_param = self[first_key]
|
|
198
|
+
self._cached_dtype = first_param.dtype
|
|
199
|
+
return self._cached_dtype
|
|
200
|
+
|
|
171
201
|
def state_dict(self, keep_vars: bool = False) -> "LazyStateDict":
|
|
172
202
|
"""
|
|
173
203
|
Args:
|
|
174
204
|
keep_vars (bool): Ignored, as LazyStateDict does not support keep_vars. Just for compatibility.
|
|
175
205
|
"""
|
|
176
|
-
return self
|
|
206
|
+
return deepcopy(self)
|
|
177
207
|
|
|
178
208
|
def _resolve_checkpoint_files(self, checkpoint: str):
|
|
179
209
|
# reference: https://huggingface.co/docs/accelerate/v0.17.1/en/usage_guides/big_modeling
|
|
@@ -290,6 +320,18 @@ class LazyStateDict:
|
|
|
290
320
|
)
|
|
291
321
|
return tensor
|
|
292
322
|
|
|
323
|
+
def pop(self, key: str):
|
|
324
|
+
assert key in list(
|
|
325
|
+
self.keys()
|
|
326
|
+
), "KeyError: Cannot pop a tensor for a key that does not exist in the LazyStateDict."
|
|
327
|
+
if self._state_dict_cache is not None and key in self._state_dict_cache:
|
|
328
|
+
if key in self._index:
|
|
329
|
+
self._index.pop(key)
|
|
330
|
+
return self._state_dict_cache.pop(key)
|
|
331
|
+
if key in self._index:
|
|
332
|
+
self._index.pop(key)
|
|
333
|
+
return None
|
|
334
|
+
|
|
293
335
|
def __setitem__(self, key: str, value: torch.Tensor) -> None:
|
|
294
336
|
"""
|
|
295
337
|
Set a tensor in the LazyStateDict. This will update the state dict cache if it is enabled.
|
|
@@ -300,9 +342,7 @@ class LazyStateDict:
|
|
|
300
342
|
if self._state_dict_cache is not None:
|
|
301
343
|
self._state_dict_cache[key] = value
|
|
302
344
|
else:
|
|
303
|
-
log.warning(
|
|
304
|
-
"State dict cache is disabled, setting a tensor will not update the cache."
|
|
305
|
-
)
|
|
345
|
+
log.warning("State dict cache is disabled, initializing the cache.")
|
|
306
346
|
self._state_dict_cache = {key: value}
|
|
307
347
|
|
|
308
348
|
def __contains__(self, key: str) -> bool:
|
|
@@ -318,7 +358,7 @@ class LazyStateDict:
|
|
|
318
358
|
self._checkpoint_files[0], key, update_cache=False
|
|
319
359
|
)
|
|
320
360
|
return tensor is not None
|
|
321
|
-
except
|
|
361
|
+
except (KeyError, FileNotFoundError, RuntimeError, EOFError):
|
|
322
362
|
return False
|
|
323
363
|
return False
|
|
324
364
|
|
|
@@ -388,8 +428,8 @@ class LazyStateDict:
|
|
|
388
428
|
)
|
|
389
429
|
|
|
390
430
|
def load_state_dict(
|
|
391
|
-
self, state_dict:
|
|
392
|
-
) ->
|
|
431
|
+
self, state_dict: Mapping[str, torch.Tensor], strict: bool = True
|
|
432
|
+
) -> _IncompatibleKeys:
|
|
393
433
|
"""
|
|
394
434
|
Load a state dict into this LazyStateDict.
|
|
395
435
|
This method is only for compatibility with nn.Module and it overrides the cache of LazyStateDict.
|
|
@@ -398,13 +438,71 @@ class LazyStateDict:
|
|
|
398
438
|
state_dict (Dict[str, torch.Tensor]): The state dict to load.
|
|
399
439
|
strict (bool): Whether to enforce that all keys in the state dict are present in this LazyStateDict.
|
|
400
440
|
"""
|
|
441
|
+
if not isinstance(state_dict, Mapping):
|
|
442
|
+
raise TypeError(
|
|
443
|
+
f"Expected state_dict to be dict-like, got {type(state_dict)}."
|
|
444
|
+
)
|
|
445
|
+
|
|
446
|
+
missing_keys: list[str] = []
|
|
447
|
+
unexpected_keys: list[str] = []
|
|
448
|
+
error_msgs: list[str] = []
|
|
449
|
+
|
|
401
450
|
log.warning(
|
|
402
451
|
"Loading state dict into LazyStateDict is not recommended, as it may lead to unexpected behavior. "
|
|
403
452
|
"Use with caution."
|
|
404
453
|
)
|
|
454
|
+
|
|
455
|
+
# Check for unexpected keys in the provided state_dict
|
|
456
|
+
for key in state_dict:
|
|
457
|
+
if key not in self:
|
|
458
|
+
unexpected_keys.append(key)
|
|
459
|
+
|
|
460
|
+
# Check for missing keys that are expected in this LazyStateDict
|
|
461
|
+
for key in self.keys():
|
|
462
|
+
if key not in state_dict:
|
|
463
|
+
missing_keys.append(key)
|
|
464
|
+
|
|
465
|
+
# Handle strict mode
|
|
405
466
|
if strict:
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
467
|
+
if len(unexpected_keys) > 0:
|
|
468
|
+
error_msgs.insert(
|
|
469
|
+
0,
|
|
470
|
+
"Unexpected key(s) in state_dict: {}. ".format(
|
|
471
|
+
", ".join(f'"{k}"' for k in unexpected_keys)
|
|
472
|
+
),
|
|
473
|
+
)
|
|
474
|
+
if len(missing_keys) > 0:
|
|
475
|
+
error_msgs.insert(
|
|
476
|
+
0,
|
|
477
|
+
"Missing key(s) in state_dict: {}. ".format(
|
|
478
|
+
", ".join(f'"{k}"' for k in missing_keys)
|
|
479
|
+
),
|
|
480
|
+
)
|
|
481
|
+
|
|
482
|
+
if len(error_msgs) > 0:
|
|
483
|
+
raise RuntimeError(
|
|
484
|
+
"Error(s) in loading state_dict for {}:\n\t{}".format(
|
|
485
|
+
self.__class__.__name__, "\n\t".join(error_msgs)
|
|
486
|
+
)
|
|
487
|
+
)
|
|
488
|
+
|
|
489
|
+
# Load the state dict values
|
|
409
490
|
for key, value in state_dict.items():
|
|
410
|
-
|
|
491
|
+
if key in self: # Only set keys that exist in this LazyStateDict
|
|
492
|
+
self[key] = value
|
|
493
|
+
|
|
494
|
+
return _IncompatibleKeys(missing_keys, unexpected_keys)
|
|
495
|
+
|
|
496
|
+
def __getattr__(self, name: str):
|
|
497
|
+
if "meta_module" in self.__dict__:
|
|
498
|
+
meta_module = self.__dict__["meta_module"]
|
|
499
|
+
if meta_module is not None:
|
|
500
|
+
if "_parameters" in meta_module.__dict__:
|
|
501
|
+
if name in meta_module.__dict__["_parameters"]:
|
|
502
|
+
return self.get_parameter(name)
|
|
503
|
+
if "_modules" in meta_module.__dict__:
|
|
504
|
+
if name in meta_module.__dict__["_modules"]:
|
|
505
|
+
return self.get_submodule(name)
|
|
506
|
+
raise AttributeError(
|
|
507
|
+
f"'{type(self).__name__}' object has no attribute '{name}'"
|
|
508
|
+
)
|
fusion_bench/utils/modelscope.py
CHANGED
|
@@ -26,13 +26,13 @@ try:
|
|
|
26
26
|
from huggingface_hub import snapshot_download as huggingface_snapshot_download
|
|
27
27
|
except ImportError:
|
|
28
28
|
|
|
29
|
-
def
|
|
29
|
+
def _raise_huggingface_not_installed_error(*args, **kwargs):
|
|
30
30
|
raise ImportError(
|
|
31
31
|
"Hugging Face Hub is not installed. Please install it using `pip install huggingface_hub` to use Hugging Face models."
|
|
32
32
|
)
|
|
33
33
|
|
|
34
|
-
huggingface_snapshot_download =
|
|
35
|
-
hf_hub_download =
|
|
34
|
+
huggingface_snapshot_download = _raise_huggingface_not_installed_error
|
|
35
|
+
hf_hub_download = _raise_huggingface_not_installed_error
|
|
36
36
|
|
|
37
37
|
__all__ = [
|
|
38
38
|
"load_dataset",
|
fusion_bench/utils/packages.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import importlib.metadata
|
|
2
2
|
import importlib.util
|
|
3
3
|
from functools import lru_cache
|
|
4
|
-
from typing import TYPE_CHECKING
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
5
|
|
|
6
6
|
from packaging import version
|
|
7
7
|
|
|
@@ -69,7 +69,7 @@ def is_vllm_available():
|
|
|
69
69
|
return _is_package_available("vllm")
|
|
70
70
|
|
|
71
71
|
|
|
72
|
-
def import_object(abs_obj_name: str):
|
|
72
|
+
def import_object(abs_obj_name: str) -> Any:
|
|
73
73
|
"""
|
|
74
74
|
Imports a class from a module given the absolute class name.
|
|
75
75
|
|
|
@@ -84,7 +84,7 @@ def import_object(abs_obj_name: str):
|
|
|
84
84
|
return getattr(module, obj_name)
|
|
85
85
|
|
|
86
86
|
|
|
87
|
-
def compare_versions(v1, v2):
|
|
87
|
+
def compare_versions(v1: str, v2: str) -> int:
|
|
88
88
|
"""Compare two version strings.
|
|
89
89
|
Returns -1 if v1 < v2, 0 if v1 == v2, 1 if v1 > v2"""
|
|
90
90
|
|
fusion_bench/utils/parameters.py
CHANGED
|
@@ -129,7 +129,6 @@ def human_readable(num: int) -> str:
|
|
|
129
129
|
Converts a number into a human-readable string with appropriate magnitude suffix.
|
|
130
130
|
|
|
131
131
|
Examples:
|
|
132
|
-
|
|
133
132
|
```python
|
|
134
133
|
print(human_readable(1500))
|
|
135
134
|
# Output: '1.50K'
|
|
@@ -201,7 +200,6 @@ def count_parameters(module: nn.Module, non_zero_only: bool = False) -> tuple[in
|
|
|
201
200
|
tuple: A tuple containing the number of trainable parameters and the total number of parameters.
|
|
202
201
|
|
|
203
202
|
Examples:
|
|
204
|
-
|
|
205
203
|
```python
|
|
206
204
|
# Count the parameters
|
|
207
205
|
trainable_params, all_params = count_parameters(model)
|
fusion_bench/utils/path.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
|
+
import logging
|
|
1
2
|
import os
|
|
2
3
|
from typing import List
|
|
3
4
|
|
|
5
|
+
log = logging.getLogger(__name__)
|
|
6
|
+
|
|
4
7
|
|
|
5
8
|
def path_is_dir_and_not_empty(path: str):
|
|
6
9
|
if path is None:
|
|
@@ -20,3 +23,56 @@ def listdir_fullpath(dir: str) -> List[str]:
|
|
|
20
23
|
assert os.path.isdir(dir), "Argument 'dir' must be a Directory"
|
|
21
24
|
names = os.listdir(dir)
|
|
22
25
|
return [os.path.join(dir, name) for name in names]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def create_symlink(src_dir: str, dst_dir: str, link_name: str = None):
|
|
29
|
+
"""
|
|
30
|
+
Creates a symbolic link from src_dir to dst_dir.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
src_dir (str): The source directory to link to.
|
|
34
|
+
dst_dir (str): The destination directory where the symlink will be created.
|
|
35
|
+
link_name (str, optional): The name of the symlink. If None, uses the basename of src_dir.
|
|
36
|
+
|
|
37
|
+
Raises:
|
|
38
|
+
OSError: If the symbolic link creation fails.
|
|
39
|
+
ValueError: If src_dir does not exist or is not a directory.
|
|
40
|
+
"""
|
|
41
|
+
if not os.path.exists(src_dir):
|
|
42
|
+
raise ValueError(f"Source directory does not exist: {src_dir}")
|
|
43
|
+
|
|
44
|
+
if not os.path.isdir(src_dir):
|
|
45
|
+
raise ValueError(f"Source path is not a directory: {src_dir}")
|
|
46
|
+
|
|
47
|
+
# Avoid creating symlink if source and destination are the same
|
|
48
|
+
if os.path.abspath(src_dir) == os.path.abspath(dst_dir):
|
|
49
|
+
log.warning(
|
|
50
|
+
"Source and destination directories are the same, skipping symlink creation"
|
|
51
|
+
)
|
|
52
|
+
return
|
|
53
|
+
|
|
54
|
+
# Create destination directory if it doesn't exist
|
|
55
|
+
os.makedirs(dst_dir, exist_ok=True)
|
|
56
|
+
|
|
57
|
+
# Determine link name
|
|
58
|
+
if link_name is None:
|
|
59
|
+
link_name = os.path.basename(src_dir)
|
|
60
|
+
|
|
61
|
+
link_path = os.path.join(dst_dir, link_name)
|
|
62
|
+
|
|
63
|
+
try:
|
|
64
|
+
# if the system is windows, use the `mklink` command in "CMD" to create the symlink
|
|
65
|
+
if os.name == "nt":
|
|
66
|
+
os.system(
|
|
67
|
+
f"mklink /J {os.path.abspath(link_path)} {os.path.abspath(src_dir)}"
|
|
68
|
+
)
|
|
69
|
+
else:
|
|
70
|
+
os.symlink(
|
|
71
|
+
src_dir,
|
|
72
|
+
link_path,
|
|
73
|
+
target_is_directory=True,
|
|
74
|
+
)
|
|
75
|
+
log.info(f"Created symbolic link: {link_path} -> {src_dir}")
|
|
76
|
+
except OSError as e:
|
|
77
|
+
log.warning(f"Failed to create symbolic link: {e}")
|
|
78
|
+
raise
|
fusion_bench/utils/pylogger.py
CHANGED
|
@@ -74,7 +74,7 @@ RankZeroLogger.manager = logging.Manager(RankZeroLogger.root)
|
|
|
74
74
|
RankZeroLogger.manager.setLoggerClass(RankZeroLogger)
|
|
75
75
|
|
|
76
76
|
|
|
77
|
-
def
|
|
77
|
+
def get_rankzero_logger(name=None):
|
|
78
78
|
"""
|
|
79
79
|
Return a logger with the specified name, creating it if necessary.
|
|
80
80
|
|
fusion_bench/utils/timer.py
CHANGED
|
@@ -6,38 +6,120 @@ log = logging.getLogger(__name__)
|
|
|
6
6
|
|
|
7
7
|
class timeit_context:
|
|
8
8
|
"""
|
|
9
|
-
|
|
9
|
+
A context manager for measuring and logging execution time of code blocks.
|
|
10
10
|
|
|
11
|
-
|
|
12
|
-
with
|
|
13
|
-
|
|
14
|
-
|
|
11
|
+
This context manager provides precise timing measurements with automatic logging
|
|
12
|
+
of elapsed time. It supports nested timing contexts with proper indentation
|
|
13
|
+
for hierarchical timing analysis, making it ideal for profiling complex
|
|
14
|
+
operations with multiple sub-components.
|
|
15
|
+
|
|
16
|
+
Args:
|
|
17
|
+
msg (str, optional): Custom message to identify the timed code block.
|
|
18
|
+
If provided, logs "[BEGIN] {msg}" at start and includes context
|
|
19
|
+
in the final timing report. Defaults to None.
|
|
20
|
+
loglevel (int, optional): Python logging level for output messages.
|
|
21
|
+
Uses standard logging levels (DEBUG=10, INFO=20, WARNING=30, etc.).
|
|
22
|
+
Defaults to logging.INFO.
|
|
23
|
+
|
|
24
|
+
Example:
|
|
25
|
+
Basic usage:
|
|
26
|
+
```python
|
|
27
|
+
with timeit_context("data loading"):
|
|
28
|
+
data = load_large_dataset()
|
|
29
|
+
# Logs: [BEGIN] data loading
|
|
30
|
+
# Logs: [END] Elapsed time: 2.34s
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
Nested timing:
|
|
34
|
+
```python
|
|
35
|
+
with timeit_context("model training"):
|
|
36
|
+
with timeit_context("data preprocessing"):
|
|
37
|
+
preprocess_data()
|
|
38
|
+
with timeit_context("forward pass"):
|
|
39
|
+
model(data)
|
|
40
|
+
# Output shows nested structure:
|
|
41
|
+
# [BEGIN] model training
|
|
42
|
+
# [BEGIN] data preprocessing
|
|
43
|
+
# [END] Elapsed time: 0.15s
|
|
44
|
+
# [BEGIN] forward pass
|
|
45
|
+
# [END] Elapsed time: 0.89s
|
|
46
|
+
# [END] Elapsed time: 1.04s
|
|
47
|
+
```
|
|
48
|
+
|
|
49
|
+
Custom log level:
|
|
50
|
+
```python
|
|
51
|
+
with timeit_context("debug operation", loglevel=logging.DEBUG):
|
|
52
|
+
debug_function()
|
|
53
|
+
```
|
|
15
54
|
"""
|
|
16
55
|
|
|
17
56
|
nest_level = -1
|
|
18
57
|
|
|
19
58
|
def _log(self, msg):
|
|
59
|
+
"""
|
|
60
|
+
Internal method for logging messages with appropriate stack level.
|
|
61
|
+
|
|
62
|
+
This helper method ensures that log messages appear to originate from
|
|
63
|
+
the caller's code rather than from internal timer methods, providing
|
|
64
|
+
more useful debugging information.
|
|
65
|
+
|
|
66
|
+
Args:
|
|
67
|
+
msg (str): The message to log at the configured log level.
|
|
68
|
+
"""
|
|
20
69
|
log.log(self.loglevel, msg, stacklevel=3)
|
|
21
70
|
|
|
22
71
|
def __init__(self, msg: str = None, loglevel=logging.INFO) -> None:
|
|
72
|
+
"""
|
|
73
|
+
Initialize a new timing context with optional message and log level.
|
|
74
|
+
|
|
75
|
+
Args:
|
|
76
|
+
msg (str, optional): Descriptive message for the timed operation.
|
|
77
|
+
If provided, will be included in the begin/end log messages
|
|
78
|
+
to help identify what is being timed. Defaults to None.
|
|
79
|
+
loglevel (int, optional): Python logging level for timer output.
|
|
80
|
+
Common values include:
|
|
81
|
+
- logging.DEBUG (10): Detailed debugging information
|
|
82
|
+
- logging.INFO (20): General information (default)
|
|
83
|
+
- logging.WARNING (30): Warning messages
|
|
84
|
+
- logging.ERROR (40): Error messages
|
|
85
|
+
Defaults to logging.INFO.
|
|
86
|
+
"""
|
|
23
87
|
self.loglevel = loglevel
|
|
24
88
|
self.msg = msg
|
|
25
89
|
|
|
26
90
|
def __enter__(self) -> None:
|
|
27
91
|
"""
|
|
28
|
-
|
|
92
|
+
Enter the timing context and start the timer.
|
|
29
93
|
|
|
30
|
-
|
|
31
|
-
|
|
94
|
+
This method is automatically called when entering the 'with' statement.
|
|
95
|
+
It records the current timestamp, increments the nesting level for
|
|
96
|
+
proper log indentation, and optionally logs a begin message.
|
|
97
|
+
|
|
98
|
+
Returns:
|
|
99
|
+
None: This context manager doesn't return a value to the 'as' clause.
|
|
100
|
+
All timing information is handled internally and logged automatically.
|
|
32
101
|
"""
|
|
33
102
|
self.start_time = time.time()
|
|
34
103
|
timeit_context.nest_level += 1
|
|
35
104
|
if self.msg is not None:
|
|
36
105
|
self._log(" " * timeit_context.nest_level + "[BEGIN] " + str(self.msg))
|
|
37
106
|
|
|
38
|
-
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
107
|
+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
|
39
108
|
"""
|
|
40
|
-
|
|
109
|
+
Exit the timing context and log the elapsed time.
|
|
110
|
+
|
|
111
|
+
This method is automatically called when exiting the 'with' statement,
|
|
112
|
+
whether through normal completion or exception. It calculates the total
|
|
113
|
+
elapsed time and logs the results with proper nesting indentation.
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
exc_type (type): Exception type if an exception occurred, None otherwise.
|
|
117
|
+
exc_val (Exception): Exception instance if an exception occurred, None otherwise.
|
|
118
|
+
exc_tb (traceback): Exception traceback if an exception occurred, None otherwise.
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
None: Does not suppress exceptions (returns None/False implicitly).
|
|
122
|
+
Any exceptions that occurred in the timed block will propagate normally.
|
|
41
123
|
"""
|
|
42
124
|
end_time = time.time()
|
|
43
125
|
elapsed_time = end_time - self.start_time
|
|
{fusion_bench-0.2.21.dist-info → fusion_bench-0.2.23.dist-info}/METADATA
CHANGED
|
@@ -1,30 +1,8 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: fusion_bench
|
|
3
|
-
Version: 0.2.21
|
|
3
|
+
Version: 0.2.23
|
|
4
4
|
Summary: A Comprehensive Benchmark of Deep Model Fusion
|
|
5
5
|
Author-email: Anke Tang <tang.anke@foxmail.com>
|
|
6
|
-
License: MIT License
|
|
7
|
-
|
|
8
|
-
Copyright (c) 2024 Anke Tang
|
|
9
|
-
|
|
10
|
-
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
11
|
-
of this software and associated documentation files (the "Software"), to deal
|
|
12
|
-
in the Software without restriction, including without limitation the rights
|
|
13
|
-
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
14
|
-
copies of the Software, and to permit persons to whom the Software is
|
|
15
|
-
furnished to do so, subject to the following conditions:
|
|
16
|
-
|
|
17
|
-
The above copyright notice and this permission notice shall be included in all
|
|
18
|
-
copies or substantial portions of the Software.
|
|
19
|
-
|
|
20
|
-
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
21
|
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
22
|
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
23
|
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
24
|
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
25
|
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26
|
-
SOFTWARE.
|
|
27
|
-
|
|
28
6
|
Project-URL: Repository, https://github.com/tanganke/fusion_bench
|
|
29
7
|
Project-URL: Homepage, https://github.com/tanganke/fusion_bench
|
|
30
8
|
Project-URL: Issues, https://github.com/tanganke/fusion_bench/issues
|