fusion-bench 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/method/__init__.py +4 -0
- fusion_bench/method/adamerging/flan_t5_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/gpt2_layer_wise_adamerging.py +1 -1
- fusion_bench/method/base_algorithm.py +1 -0
- fusion_bench/method/dawe/dawe_for_clip.py +1 -1
- fusion_bench/method/depth_upscaling/depth_upscaling_for_llama.py +3 -2
- fusion_bench/method/fw_merging/__init__.py +2 -0
- fusion_bench/method/fw_merging/fw_hard.py +448 -0
- fusion_bench/method/fw_merging/fw_soft.py +519 -0
- fusion_bench/method/fw_merging/utils.py +331 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +1 -1
- fusion_bench/method/moe_pruner/__init__.py +7 -0
- fusion_bench/method/moe_pruner/hooks/__init__.py +6 -0
- fusion_bench/method/moe_pruner/hooks/deepseek_v2.py +85 -0
- fusion_bench/method/moe_pruner/hooks/hook.py +23 -0
- fusion_bench/method/moe_pruner/hooks/mixtral.py +93 -0
- fusion_bench/method/moe_pruner/moe_pruner.py +304 -0
- fusion_bench/method/moe_pruner/utils/__init__.py +1 -0
- fusion_bench/method/moe_pruner/utils/data.py +154 -0
- fusion_bench/method/moe_pruner/utils/layerwrapper.py +61 -0
- fusion_bench/method/moe_pruner/utils/prune.py +313 -0
- fusion_bench/method/moe_pruner/utils/score.py +41 -0
- fusion_bench/method/pruning/__init__.py +1 -0
- fusion_bench/method/pruning/llama_sparsegpt_prune.py +223 -0
- fusion_bench/method/pruning/sparsegpt_utils/__init__.py +1 -0
- fusion_bench/method/pruning/sparsegpt_utils/sparsegpt.py +128 -0
- fusion_bench/method/pruning/wanda_utils/data.py +33 -14
- fusion_bench/method/pwe_moe/module.py +2 -7
- fusion_bench/method/randes/__init__.py +15 -0
- fusion_bench/method/randes/base_algorithm.py +1013 -0
- fusion_bench/method/randes/modelsoup.py +126 -0
- fusion_bench/method/randes/task_arithmetic.py +318 -0
- fusion_bench/method/simple_average.py +3 -2
- fusion_bench/method/sparselo/sparselo.py +20 -2
- fusion_bench/method/tall_mask/__init__.py +1 -0
- fusion_bench/method/tall_mask/task_arithmetic.py +133 -0
- fusion_bench/method/task_singular_vector/TSVM.py +238 -25
- fusion_bench/method/task_singular_vector/utils/TSVM_utils.py +52 -20
- fusion_bench/mixins/hydra_config.py +1 -1
- fusion_bench/mixins/lightning_fabric.py +25 -1
- fusion_bench/mixins/serialization.py +18 -2
- fusion_bench/modelpool/base_pool.py +1 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +21 -13
- fusion_bench/modelpool/lazy_state_dict_pool.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +15 -0
- fusion_bench/models/modeling_deepseek_v2/configuration_deepseek.py +208 -0
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +1922 -0
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +38 -0
- fusion_bench/models/parameter_dict.py +6 -1
- fusion_bench/programs/fabric_fusion_program.py +14 -5
- fusion_bench/taskpool/base_pool.py +1 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +8 -1
- fusion_bench/taskpool/dummy.py +6 -4
- fusion_bench/utils/__init__.py +2 -1
- fusion_bench/utils/data.py +1 -1
- fusion_bench/utils/{instantiate.py → instantiate_utils.py} +3 -0
- fusion_bench/utils/lazy_state_dict.py +268 -0
- fusion_bench/utils/parameters.py +33 -0
- fusion_bench/utils/pylogger.py +28 -0
- fusion_bench/utils/state_dict_arithmetic.py +74 -2
- fusion_bench/utils/type.py +1 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/METADATA +8 -2
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/RECORD +104 -44
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/WHEEL +1 -1
- fusion_bench_config/dataset/image_classification/test/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/test/TALL18.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL10.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL12.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL16.yaml +28 -0
- fusion_bench_config/dataset/image_classification/train/TALL18.yaml +28 -0
- fusion_bench_config/fabric_model_fusion.yaml +2 -2
- fusion_bench_config/method/fw_merging/fw_hard.yaml +11 -0
- fusion_bench_config/method/fw_merging/fw_soft.yaml +12 -0
- fusion_bench_config/method/moe_pruner/moe_pruner.yaml +15 -0
- fusion_bench_config/method/pruning/llama_sparsegpt_pruning.yaml +16 -0
- fusion_bench_config/method/randes/superposed_model_soup.yaml +18 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic.yaml +20 -0
- fusion_bench_config/method/randes/superposed_task_arithmetic_lora.yaml +20 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +2 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/tall_mask/task_arithmetic.yaml +4 -0
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +2 -1
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL10.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL12.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL16.yaml +29 -0
- fusion_bench_config/model/clip-vit/clip-vit-base-patch32_TALL18.yaml +29 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +8 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_cars_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +16 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/deepseek-v2-lite.yaml +15 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral-8x7b.yaml +14 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/roberta-base_glue.yaml +69 -0
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +0 -1
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/top_level.txt +0 -0
fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py
ADDED

@@ -0,0 +1,38 @@
+from typing import List, Optional, Union
+
+
+from transformers.models.llama import LlamaTokenizerFast
+
+
+class DeepseekTokenizerFast(LlamaTokenizerFast):
+
+    def convert_ids_to_tokens(
+        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
+    ) -> Union[str, List[str]]:
+        """
+        Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
+        added tokens.
+
+        Args:
+            ids (`int` or `List[int]`):
+                The token id (or token ids) to convert to tokens.
+            skip_special_tokens (`bool`, *optional*, defaults to `False`):
+                Whether or not to remove special tokens in the decoding.
+
+        Returns:
+            `str` or `List[str]`: The decoded token(s).
+        """
+        if isinstance(ids, int):
+            return self._convert_id_to_token(ids)
+        tokens = []
+        for index in ids:
+            index = int(index)
+            if skip_special_tokens and index in self.all_special_ids:
+                continue
+            token = self._tokenizer.id_to_token(index)
+            tokens.append(token if token is not None else "")
+        return tokens
+
+    def _convert_id_to_token(self, index: int) -> Optional[str]:
+        token = self._tokenizer.id_to_token(int(index))
+        return token if token is not None else ""
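The notable behavior of the new `convert_ids_to_tokens` is that out-of-vocabulary ids become empty strings instead of raising. A minimal standalone sketch of that logic, with a made-up toy vocab and special-id set in place of the HF tokenizer backend:

```python
# Sketch of the convert_ids_to_tokens logic above; TOY_VOCAB and
# SPECIAL_IDS are hypothetical stand-ins for the tokenizer state.
from typing import List, Optional, Union

TOY_VOCAB = {0: "<s>", 1: "</s>", 2: "hello", 3: "world"}
SPECIAL_IDS = {0, 1}

def id_to_token(index: int) -> Optional[str]:
    return TOY_VOCAB.get(int(index))

def convert_ids_to_tokens(ids, skip_special_tokens: bool = False):
    if isinstance(ids, int):
        # unknown single id maps to "" rather than raising
        return id_to_token(ids) or ""
    tokens = []
    for index in ids:
        index = int(index)
        if skip_special_tokens and index in SPECIAL_IDS:
            continue
        token = id_to_token(index)
        tokens.append(token if token is not None else "")
    return tokens

print(convert_ids_to_tokens([0, 2, 3, 1], skip_special_tokens=True))
```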
fusion_bench/models/parameter_dict.py
CHANGED

@@ -66,7 +66,9 @@ class ParameterDictModel(nn.Module):
         super().__init__()
         if parameters is not None:
             for name, param in parameters.items():
-                assert isinstance(
+                assert isinstance(
+                    param, (nn.Parameter, nn.Buffer)
+                ), f"{name} is not a nn.Parameter or nn.Buffer"
                 _set_attr(
                     self,
                     name.split("."),

@@ -114,3 +116,6 @@ class ParameterDictModel(nn.Module):
 
     def values(self) -> List[nn.Parameter]:
         return [self[name] for name in self.keys()]
+
+    def __len__(self):
+        return len(self.keys())
fusion_bench/programs/fabric_fusion_program.py
CHANGED

@@ -9,7 +9,7 @@ from omegaconf import DictConfig, OmegaConf
 from torch import nn
 from tqdm.auto import tqdm
 
-import fusion_bench.utils.
+import fusion_bench.utils.instantiate_utils
 from fusion_bench.method import BaseAlgorithm
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.modelpool import BaseModelPool

@@ -19,8 +19,9 @@ from fusion_bench.utils import import_object, instantiate, timeit_context
 from fusion_bench.utils.hydra_utils import get_hydra_output_dir
 from fusion_bench.utils.json import print_json
 from fusion_bench.utils.rich_utils import print_bordered, print_config_tree
+from fusion_bench.utils.pylogger import getRankZeroLogger
 
-log =
+log = getRankZeroLogger(__name__)
 
 
 class FabricModelFusionProgram(

@@ -66,8 +67,8 @@ class FabricModelFusionProgram(
         self.merged_model_save_kwargs = merged_model_save_kwargs
         self.fast_dev_run = fast_dev_run
         self.seed = seed
+        fusion_bench.utils.instantiate_utils.PRINT_FUNCTION_CALL = print_function_call
         super().__init__(**kwargs)
-        fusion_bench.utils.instantiate.PRINT_FUNCTION_CALL = print_function_call
 
         if print_config:
             print_config_tree(

@@ -196,6 +197,11 @@ class FabricModelFusionProgram(
         for key, item in merged_model.items():
             if isinstance(item, nn.Module):
                 report[key] = taskpool.evaluate(item, *args, **kwargs)
+            elif key == "models":
+                # for multi-model evaluation
+                report[key] = self.evaluate_merged_model(
+                    taskpool, item, *args, **kwargs
+                )
             else:
                 # metadata
                 report[key] = item

@@ -247,13 +253,16 @@ class FabricModelFusionProgram(
         if self.taskpool is not None:
             report = self.evaluate_merged_model(self.taskpool, merged_model)
             try:
-
+                if rank_zero_only.rank == 0:
+                    print_json(report, print_type=False)
             except Exception as e:
                 log.warning(f"Failed to pretty print the report: {e}")
-
+                log.info(report)
             if self.report_save_path is not None:
                 # save report (Dict) to a file
                 # if the directory of `save_report` does not exists, create it
+                if "{log_dir}" in self.report_save_path and self.log_dir is not None:
+                    self.report_save_path = self.report_save_path.format(log_dir=self.log_dir)
                 os.makedirs(os.path.dirname(self.report_save_path), exist_ok=True)
                 json.dump(report, open(self.report_save_path, "w"))
             else:
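The `{log_dir}` placeholder substitution added to `report_save_path` handling can be sketched in isolation (the helper name and values here are illustrative, not part of the package):

```python
# Sketch of the conditional "{log_dir}" substitution above: the
# placeholder is only expanded when it is present and a log dir exists.
def resolve_report_save_path(report_save_path: str, log_dir):
    if "{log_dir}" in report_save_path and log_dir is not None:
        report_save_path = report_save_path.format(log_dir=log_dir)
    return report_save_path

print(resolve_report_save_path("{log_dir}/report.json", "outputs/run1"))
```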
fusion_bench/taskpool/clip_vision/taskpool.py
CHANGED

@@ -348,8 +348,15 @@ class CLIPVisionModelTaskPool(
 
         log.info(f"Evaluation Result: {report}")
         if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
-
+            save_path = os.path.join(self.log_dir, "report.json")
+            for version in itertools.count(1):
+                if not os.path.exists(save_path):
+                    break
+                # if the file already exists, increment the version to avoid overwriting
+                save_path = os.path.join(self.log_dir, f"report_{version}.json")
+            with open(save_path, "w") as fp:
                json.dump(report, fp)
+            log.info(f"Evaluation report saved to {save_path}")
         return report
 
     def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
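The report-versioning loop added above picks `report.json`, then `report_1.json`, `report_2.json`, and so on until a free name is found, so earlier evaluation reports are never overwritten. It can be exercised standalone (the helper name `next_report_path` is ours, not the package's):

```python
# Standalone sketch of the itertools.count versioning loop above.
import itertools
import os
import tempfile

def next_report_path(log_dir: str) -> str:
    save_path = os.path.join(log_dir, "report.json")
    for version in itertools.count(1):
        if not os.path.exists(save_path):
            break
        # if the file already exists, increment the version to avoid overwriting
        save_path = os.path.join(log_dir, f"report_{version}.json")
    return save_path

with tempfile.TemporaryDirectory() as d:
    for _ in range(3):
        path = next_report_path(d)
        open(path, "w").close()  # simulate writing a report
    print(sorted(os.listdir(d)))
```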
fusion_bench/taskpool/dummy.py
CHANGED

@@ -10,6 +10,7 @@ from fusion_bench.models.separate_io import separate_save
 from fusion_bench.taskpool.base_pool import BaseTaskPool
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.parameters import count_parameters, print_parameters
+from lightning.pytorch.utilities import rank_zero_only
 
 
 def get_model_summary(model: nn.Module) -> dict:

@@ -49,10 +50,11 @@ class DummyTaskPool(BaseTaskPool):
         Args:
             model: The model to evaluate.
         """
-
+        if rank_zero_only.rank == 0:
+            print_parameters(model, is_human_readable=True)
 
-
-
-
+        if self.model_save_path is not None:
+            with timeit_context(f"Saving the model to {self.model_save_path}"):
+                separate_save(model, self.model_save_path)
 
         return get_model_summary(model)
fusion_bench/utils/__init__.py
CHANGED

@@ -7,8 +7,9 @@ from .cache_utils import *
 from .devices import *
 from .dtype import parse_dtype
 from .fabric import seed_everything_by_time
-from .
+from .instantiate_utils import instantiate, is_instantiable
 from .misc import *
 from .packages import import_object
 from .parameters import *
 from .timer import timeit_context
+from .lazy_state_dict import LazyStateDict
fusion_bench/utils/data.py
CHANGED

@@ -96,7 +96,7 @@ def train_validation_split(
 
     # Compute the number of samples for training and validation
     num_samples = len(dataset)
-    if validation_size is
+    if validation_size is None:
        assert (
            0 < validation_fraction < 1
        ), "Validation fraction must be between 0 and 1"
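The `validation_size is None` branch derives the split from `validation_fraction`. A standalone sketch of the size computation, under the assumption that the fractional size is truncated with `int(...)` (the names mirror the function above; the rounding rule is our guess, not confirmed by the diff):

```python
# Sketch of the train/validation size computation in
# train_validation_split; int(...) truncation is an assumption.
def split_sizes(num_samples, validation_fraction=None, validation_size=None):
    if validation_size is None:
        assert 0 < validation_fraction < 1, "Validation fraction must be between 0 and 1"
        validation_size = int(num_samples * validation_fraction)
    train_size = num_samples - validation_size
    return train_size, validation_size

print(split_sizes(100, validation_fraction=0.2))
```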
fusion_bench/utils/instantiate_utils.py (renamed from fusion_bench/utils/instantiate.py)
CHANGED

@@ -41,6 +41,9 @@ def set_print_function_call(value: bool):
     finally:
         PRINT_FUNCTION_CALL = old_value
 
+def set_print_function_call_permeanent(value: bool):
+    global PRINT_FUNCTION_CALL
+    PRINT_FUNCTION_CALL = value
 
 def is_instantiable(config: Union[DictConfig, Any]) -> bool:
     if OmegaConf.is_dict(config):
fusion_bench/utils/lazy_state_dict.py
ADDED

@@ -0,0 +1,268 @@
+import json
+import logging
+import os
+from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple
+
+import torch
+from accelerate.utils.constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
+from huggingface_hub import snapshot_download
+from safetensors import safe_open
+from safetensors.torch import load_file
+from transformers import AutoConfig
+
+from fusion_bench.utils.dtype import parse_dtype
+
+if TYPE_CHECKING:
+    from transformers import PretrainedConfig
+
+log = logging.getLogger(__name__)
+
+__all__ = ["resolve_checkpoint_path", "LazyStateDict"]
+
+
+def resolve_checkpoint_path(
+    checkpoint: str,
+    hf_revision: Optional[str] = None,
+    hf_cache_dir: Optional[str] = None,
+    hf_proxies: Optional[Dict] = None,
+):
+    # If it's a local file or directory, return as is
+    if os.path.exists(checkpoint):
+        return checkpoint
+    # If it's a HuggingFace Hub model id, download snapshot
+    try:
+        # This will download the model to the cache and return the local path
+        local_path = snapshot_download(
+            repo_id=checkpoint,
+            revision=hf_revision,
+            cache_dir=hf_cache_dir,
+            proxies=hf_proxies,
+        )
+        return local_path
+    except Exception as e:
+        raise FileNotFoundError(
+            f"Could not resolve checkpoint: {checkpoint}. Error: {e}"
+        )
+
+
+class LazyStateDict:
+    """
+    Dictionary-like object that lazily loads a state dict from a checkpoint path.
+    """
+
+    _local_path: str
+    _state_dict_cache: Optional[Dict]
+    _index_filename: Optional[str]
+    _checkpoint_files: Optional[List[str]]
+    _index: Optional[Dict]
+
+    def __init__(
+        self,
+        checkpoint: str,
+        cache_state_dict: bool = False,
+        torch_dtype: Optional[torch.dtype] = None,
+        device: str = "cpu",
+        hf_revision: Optional[str] = None,
+        hf_cache_dir: Optional[str] = None,
+        hf_proxies: Optional[Dict] = None,
+    ):
+        self._checkpoint = checkpoint
+        self._local_path = resolve_checkpoint_path(
+            checkpoint,
+            hf_revision=hf_revision,
+            hf_cache_dir=hf_cache_dir,
+            hf_proxies=hf_proxies,
+        )
+
+        self._index, self._index_filename, self._checkpoint_files = (
+            self._resolve_checkpoint_files(self._local_path)
+        )
+
+        if cache_state_dict:
+            self._state_dict_cache = {}
+        else:
+            self._state_dict_cache = None
+
+        self._torch_dtype = parse_dtype(torch_dtype)
+        self._device = device
+
+    @property
+    def checkpoint(self) -> str:
+        return self._checkpoint
+
+    @property
+    def config(self) -> "PretrainedConfig":
+        return AutoConfig.from_pretrained(self._checkpoint)
+
+    def state_dict(self) -> "LazyStateDict":
+        return self
+
+    def _resolve_checkpoint_files(self, checkpoint: str):
+        # reference: https://huggingface.co/docs/accelerate/v0.17.1/en/usage_guides/big_modeling
+        checkpoint_files = None
+        index_filename = None
+        if os.path.isfile(checkpoint):
+            if str(checkpoint).endswith(".json"):
+                index_filename = checkpoint
+            else:
+                checkpoint_files = [checkpoint]
+        elif os.path.isdir(checkpoint):
+            # check if the whole state dict is present
+            potential_state_bin = [
+                f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME
+            ]
+            potential_state_safetensor = [
+                f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME
+            ]
+            if len(potential_state_bin) == 1:
+                checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])]
+            elif len(potential_state_safetensor) == 1:
+                checkpoint_files = [
+                    os.path.join(checkpoint, potential_state_safetensor[0])
+                ]
+            else:
+                # otherwise check for sharded checkpoints
+                potential_index = [
+                    f for f in os.listdir(checkpoint) if f.endswith(".index.json")
+                ]
+                if len(potential_index) == 0:
+                    raise ValueError(
+                        f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file"
+                    )
+                elif len(potential_index) == 1:
+                    index_filename = os.path.join(checkpoint, potential_index[0])
+                else:
+                    raise ValueError(
+                        f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
+                    )
+        else:
+            raise ValueError(
+                "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
+                f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}."
+            )
+
+        if index_filename is not None:
+            checkpoint_folder = os.path.split(index_filename)[0]
+            with open(index_filename) as f:
+                index = json.loads(f.read())
+
+            if "weight_map" in index:
+                index = index["weight_map"]
+            checkpoint_files = sorted(list(set(index.values())))
+            checkpoint_files = [
+                os.path.join(checkpoint_folder, f) for f in checkpoint_files
+            ]
+        return index, index_filename, checkpoint_files
+
+    def _load_tensor_from_checkpoint_file(
+        self, checkpoint_file: str, key: str, update_cache: bool = True
+    ) -> torch.Tensor:
+        if checkpoint_file.endswith(".safetensors"):
+            with safe_open(checkpoint_file, framework="pt", device=self._device) as f:
+                tensor = f.get_tensor(key)
+                if self._torch_dtype is not None:
+                    tensor = tensor.to(self._torch_dtype)
+                if update_cache and self._state_dict_cache is not None:
+                    self._state_dict_cache[key] = tensor
+                return tensor
+        else:
+            state_dict = torch.load(checkpoint_file, map_location=self._device)
+            if update_cache:
+                if self._state_dict_cache is not None:
+                    self._state_dict_cache.update(state_dict)
+                else:
+                    log.warning(
+                        f"Load full state dict from file {checkpoint_file}, but state dict cache is disabled."
+                    )
+            return state_dict[key]
+
+    def __getitem__(self, key: str) -> torch.Tensor:
+        if self._state_dict_cache is not None and key in self._state_dict_cache:
+            return self._state_dict_cache[key]
+
+        if self._index is None:
+            if len(self._checkpoint_files) == 1 and os.path.isfile(
+                self._checkpoint_files[0]
+            ):
+                checkpoint_file = self._checkpoint_files[0]
+                tensor = self._load_tensor_from_checkpoint_file(
+                    checkpoint_file, key, update_cache=True
+                )
+                return tensor
+            else:
+                if len(self._checkpoint_files) > 1:
+                    raise RuntimeError(
+                        "Get multiple checkpoint files, but index is not provided."
+                    )
+                if not os.path.isfile(self._checkpoint_files[0]):
+                    raise FileNotFoundError(
+                        f"Checkpoint file {self._checkpoint_files[0]} not found."
+                    )
+                raise RuntimeError("Unexpected error.")
+        else:
+            if key not in self._index:
+                raise KeyError(f"Key {key} not found in index.")
+            checkpoint_file = os.path.join(self._local_path, self._index[key])
+            if not os.path.isfile(checkpoint_file):
+                raise FileNotFoundError(f"Checkpoint file {checkpoint_file} not found.")
+            tensor = self._load_tensor_from_checkpoint_file(
+                checkpoint_file, key, update_cache=True
+            )
+            return tensor
+
+    def __contains__(self, key: str) -> bool:
+        if self._state_dict_cache is not None and key in self._state_dict_cache:
+            return True
+        if self._index is not None and key in self._index:
+            return True
+        if len(self._checkpoint_files) == 1 and os.path.isfile(
+            self._checkpoint_files[0]
+        ):
+            try:
+                tensor = self._load_tensor_from_checkpoint_file(
+                    self._checkpoint_files[0], key, update_cache=False
+                )
+                return tensor is not None
+            except Exception:
+                return False
+        return False
+
+    def __len__(self) -> int:
+        if self._index is not None:
+            return len(self._index)
+        if len(self._checkpoint_files) == 1 and os.path.isfile(
+            self._checkpoint_files[0]
+        ):
+            checkpoint_file = self._checkpoint_files[0]
+            if checkpoint_file.endswith(".safetensors"):
+                with safe_open(checkpoint_file, framework="pt", device="cpu") as f:
+                    return len(tuple(f.keys()))
+            else:
+                return len(
+                    tuple(torch.load(checkpoint_file, map_location="cpu").keys())
+                )
+        raise RuntimeError(
+            "Unexpected error: cannot determine the number of keys in the state dict."
+        )
+
+    def __iter__(self) -> Iterator[str]:
+        if self._index is not None:
+            return iter(self._index)
+        return iter(self._checkpoint_files)
+
+    def keys(self) -> List[str]:
+        return list(self)
+
+    def values(self) -> List[torch.Tensor]:
+        return [self[key] for key in self]
+
+    def items(self) -> Iterator[Tuple[str, torch.Tensor]]:
+        return ((key, self[key]) for key in self)
+
+    def __repr__(self) -> str:
+        if self._index is not None:
+            return f"{self.__class__.__name__}(index={self._index})"
+        else:
+            return (
+                f"{self.__class__.__name__}(checkpoint_files={self._checkpoint_files})"
+            )
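The sharded-checkpoint handling in `_resolve_checkpoint_files` boils down to flattening the index's `weight_map` (tensor name → shard file) into a sorted, de-duplicated list of shard files. A standalone sketch of that step with a made-up index:

```python
# Sketch of the weight_map flattening in _resolve_checkpoint_files
# above; the index content is a made-up example.
import os

def shard_files_from_index(index: dict, checkpoint_folder: str):
    # the on-disk ".index.json" nests the mapping under "weight_map"
    if "weight_map" in index:
        index = index["weight_map"]
    checkpoint_files = sorted(set(index.values()))
    return index, [os.path.join(checkpoint_folder, f) for f in checkpoint_files]

index = {
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "lm_head.weight": "model-00002-of-00002.safetensors",
        "model.norm.weight": "model-00002-of-00002.safetensors",
    }
}
weight_map, files = shard_files_from_index(index, "ckpt")
print(files)
```

With the index in hand, `__getitem__` only has to open the one shard that holds the requested key, which is what makes the state dict "lazy".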
fusion_bench/utils/parameters.py
CHANGED

@@ -222,6 +222,39 @@ def count_parameters(module: nn.Module, non_zero_only: bool = False) -> tuple[in
     return trainable_params, all_param
 
 
+@torch.no_grad()
+def get_parameter_summary(
+    module_or_state_dict: Union[nn.Module, StateDictType], non_zero_only: bool = False
+) -> dict:
+    """
+    Get a summary of the parameters in a PyTorch model.
+    """
+    if isinstance(module_or_state_dict, nn.Module):
+        state_dict = module_or_state_dict.state_dict(keep_vars=True)
+    else:
+        state_dict = module_or_state_dict
+
+    trainable_params = 0
+    all_param = 0
+    bytes = 0
+
+    for name, param in state_dict.items():
+        # count the number of parameters
+        num_params = _numel(param, non_zero_only)
+        bytes += _numel(param, non_zero_only=False) * param.element_size()
+
+        # accumulate the number of trainable and total parameters
+        all_param += num_params
+        if param.requires_grad:
+            trainable_params += num_params
+
+    return {
+        "trainable_params": trainable_params,
+        "all_param": all_param,
+        "bytes": bytes,
+    }
+
+
 def print_parameters(
     module: nn.Module,
     is_human_readable: bool = True,
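The accounting in `get_parameter_summary` can be checked without torch by substituting a tiny stand-in parameter type: only `numel`, `element_size`, and `requires_grad` are used, and the `non_zero_only` branch is omitted here for brevity. A sketch under those assumptions:

```python
# Stand-in for a tensor parameter; FakeParam is hypothetical.
from dataclasses import dataclass

@dataclass
class FakeParam:
    n: int                 # number of elements
    itemsize: int = 4      # bytes per element (float32)
    requires_grad: bool = True

    def numel(self):
        return self.n

    def element_size(self):
        return self.itemsize

def get_parameter_summary(state_dict: dict) -> dict:
    trainable_params, all_param, nbytes = 0, 0, 0
    for name, param in state_dict.items():
        num_params = param.numel()
        nbytes += param.numel() * param.element_size()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    return {"trainable_params": trainable_params, "all_param": all_param, "bytes": nbytes}

sd = {"weight": FakeParam(10), "frozen": FakeParam(5, requires_grad=False)}
print(get_parameter_summary(sd))
```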
fusion_bench/utils/pylogger.py
CHANGED

@@ -53,3 +53,31 @@ class RankedLogger(logging.LoggerAdapter):
             self.logger.log(level, msg, *args, **kwargs)
         elif current_rank == rank:
             self.logger.log(level, msg, *args, **kwargs)
+
+
+class RankZeroLogger(logging.Logger):
+    """A logger that logs only on rank zero and works just like logging.Logger"""
+
+    @rank_zero_only
+    def _log(self, *args, **kwargs):
+        if "stacklevel" in kwargs:
+            kwargs["stacklevel"] += 1
+        return super()._log(*args, **kwargs)
+
+    def is_global_zero(self):
+        return rank_zero_only.rank == 0
+
+
+RankZeroLogger.manager = logging.Manager(RankZeroLogger.root)
+RankZeroLogger.manager.setLoggerClass(RankZeroLogger)
+
+
+def getRankZeroLogger(name=None):
+    """
+    Return a logger with the specified name, creating it if necessary.
+
+    If no name is specified, return the root logger.
+    """
+    if not name or isinstance(name, str) and name == logging.root.name:
+        return logging.root
+    return RankZeroLogger.manager.getLogger(name)
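The idea behind `RankZeroLogger` is to gate `Logger._log` on the process rank so only rank zero emits records. A minimal standard-library sketch, where the module-level `RANK` stands in for `rank_zero_only.rank`:

```python
# Stdlib sketch of the rank-zero gating above; RANK is a stand-in for
# the distributed rank (set from the environment in real use).
import io
import logging

RANK = 0  # assumed global rank of this process

class RankZeroLogger(logging.Logger):
    def _log(self, *args, **kwargs):
        if RANK != 0:
            return None  # silently drop records on non-zero ranks
        if "stacklevel" in kwargs:
            kwargs["stacklevel"] += 1  # report the caller, not this wrapper
        return super()._log(*args, **kwargs)

stream = io.StringIO()
logger = RankZeroLogger("demo")
logger.addHandler(logging.StreamHandler(stream))
logger.setLevel(logging.INFO)

logger.info("visible on rank 0")
RANK = 1
logger.info("suppressed on other ranks")
print(stream.getvalue())
```

Overriding `_log` (rather than `info`, `warning`, etc.) catches every level in one place, which is why the real class bumps `stacklevel` so `%(filename)s`-style formatting still points at the caller.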
fusion_bench/utils/state_dict_arithmetic.py
CHANGED

@@ -1,12 +1,12 @@
 from collections import OrderedDict
 from numbers import Number
-from typing import Dict, List, Union, cast
+from typing import Callable, Dict, List, Literal, Union, cast
 
 import torch
 from torch import Tensor
 
 from .parameters import check_parameters_all_equal
-from .type import StateDictType
+from .type import BoolStateDictType, StateDictType
 
 
 def to_device(

@@ -295,3 +295,75 @@ def state_dict_weighted_sum(
             device, non_blocking=True
         )
     return weighted_sum_state_dict
+
+
+def state_dict_diff_abs(a: StateDictType, b: StateDictType):
+    """
+    Returns the per-layer abs of the difference between two state dicts.
+
+    Args:
+        a (StateDictType): The first state dict.
+        b (StateDictType): The second state dict.
+
+    Returns:
+        StateDictType: The absolute difference between the two state dicts.
+    """
+    diff = state_dict_sub(a, b)
+    abs_diff = {key: diff[key].abs() for key in diff}
+    return abs_diff
+
+
+def state_dict_binary_mask(
+    a: StateDictType,
+    b: StateDictType,
+    compare_fn: Union[
+        Literal["greater", "less", "equal", "not_equal"],
+        Callable[[Tensor, Tensor], torch.BoolTensor],
+    ] = "greater",
+) -> BoolStateDictType:
+    """
+    Returns the binary mask of elements in a compared to elements in b using the provided comparison function.
+
+    Args:
+        a (StateDictType): The first state dict.
+        b (StateDictType): The second state dict.
+        compare_fn (Union[Literal["greater", "less", "equal", "not_equal"], Callable[[Tensor, Tensor], Tensor]]): A function that takes two tensors and returns a boolean tensor.
+            Defaults to greater than comparison (x > y).
+
+    Returns:
+        StateDictType: A dictionary containing binary masks (0 or 1) based on the comparison.
+    """
+    compare_fn_dict = {
+        "greater": lambda x, y: x > y,
+        "less": lambda x, y: x < y,
+        "equal": lambda x, y: x == y,
+        "not_equal": lambda x, y: x != y,
+    }
+    if isinstance(compare_fn, str):
+        compare_fn = compare_fn_dict[compare_fn]
+    elif not callable(compare_fn):
+        raise ValueError(
+            f"compare_fn must be a string or a callable, but got {type(compare_fn)}"
+        )
+
+    mask = OrderedDict()
+    for key in a:
+        mask[key] = compare_fn(a[key], b[key])
+    return mask
+
+
+def state_dict_hadmard_product(a: StateDictType, b: StateDictType) -> StateDictType:
+    """
+    Returns the Hadamard product of two state dicts, i.e. element-wise product.
+
+    Args:
+        a (StateDictType): The first state dict.
+        b (StateDictType): The second state dict.
+
+    Returns:
+        StateDictType: The Hadamard product of the two state dicts.
+    """
+    ans = OrderedDict()
+    for key in a:
+        ans[key] = a[key] * b[key]
+    return ans
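The three new helpers compose naturally: threshold the absolute difference of two state dicts into a binary mask, then apply the mask with the Hadamard product. A sketch over dicts of plain numbers instead of tensor state dicts (`state_dict_hadmard_product` keeps the code's spelling; the threshold values are illustrative):

```python
# Scalar sketch of the new state-dict helpers and how they compose.
from collections import OrderedDict

def state_dict_sub(a, b):
    return OrderedDict((k, a[k] - b[k]) for k in a)

def state_dict_diff_abs(a, b):
    diff = state_dict_sub(a, b)
    return {k: abs(v) for k, v in diff.items()}

def state_dict_binary_mask(a, b, compare_fn="greater"):
    fns = {
        "greater": lambda x, y: x > y,
        "less": lambda x, y: x < y,
        "equal": lambda x, y: x == y,
        "not_equal": lambda x, y: x != y,
    }
    fn = fns[compare_fn] if isinstance(compare_fn, str) else compare_fn
    return OrderedDict((k, fn(a[k], b[k])) for k in a)

def state_dict_hadmard_product(a, b):  # name as in the diff
    return OrderedDict((k, a[k] * b[k]) for k in a)

a = {"w": 3.0, "b": -1.0}
b = {"w": 1.0, "b": 2.0}
threshold = {"w": 2.5, "b": 1.5}          # per-key threshold, illustrative
mask = state_dict_binary_mask(state_dict_diff_abs(a, b), threshold)
masked = state_dict_hadmard_product(mask, state_dict_sub(a, b))
print(mask, masked)
```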
fusion_bench/utils/type.py
CHANGED

{fusion_bench-0.2.15.dist-info → fusion_bench-0.2.17.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: fusion_bench
-Version: 0.2.
+Version: 0.2.17
 Summary: A Comprehensive Benchmark of Deep Model Fusion
 Author-email: Anke Tang <tang.anke@foxmail.com>
 License: MIT License

@@ -70,7 +70,7 @@ Dynamic: license-file
 
 FusionBench is a benchmark suite designed to evaluate the performance of various deep model fusion techniques. It aims to provide a comprehensive comparison of different methods on a variety of datasets and tasks.
 
-Projects based on FusionBench and news from the community (descending order of date):
+Projects based on FusionBench and news from the community (descending order of date. If you have any work based on FusionBench, please feel free to let us know, we are willing to add it to the list. :partying_face:):
 
 <details>
 <summary>Hao Mark Chen, et al. FW-Merging: Scaling Model Merging with Frank-Wolfe Optimization. Mar 2025. https://arxiv.org/abs/2503.12649</summary>

@@ -139,6 +139,10 @@ cd fusion_bench
 pip install -e . # install the package in editable mode
 ```
 
+> [!TIP]
+> FusionBench is highly dependent on the use of [Hydra](https://hydra.cc/) for configuration management and command line argument parsing, and [Lightning Fabric](https://lightning.ai/) for device management.
+> If you are not familiar with these tools, it is strongly recommended to read the [Hydra](https://hydra.cc/docs/intro/) and [Lightning Fabric](https://lightning.ai/docs/fabric/stable/) documentation.
+
 ### Install with [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness)
 
 [](https://doi.org/10.5281/zenodo.10256836)

@@ -167,6 +171,8 @@ It can be used to improve the performance and robustness of model or to combine
 For a more detailed introduction to deep model fusion, you can refer to [W. Li, 2023, 'Deep Model Fusion: A Survey'](https://arxiv.org/abs/2309.15698). We also provide a brief overview of deep model fusion in [our documentation](https://tanganke.github.io/fusion_bench/).
 In this benchmark, we evaluate the performance of different fusion methods on a variety of datasets and tasks.
 
+A comprehensive list of papers about model merging can be found at [this repository](https://github.com/EnnengYang/Awesome-Model-Merging-Methods-Theories-Applications), and [the arXiv paper](https://arxiv.org/abs/2408.07666) is also available.
+
 ## Project Structure
 
 The project is structured as follows: