fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -1
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +16 -6
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
- fusion_bench/method/simple_average.py +16 -4
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +6 -6
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/misc.py +48 -2
- fusion_bench/utils/modelscope.py +265 -0
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +34 -27
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
|
@@ -56,6 +56,8 @@ class LayerWiseFeatureSaver:
|
|
|
56
56
|
first_token_only: bool = True,
|
|
57
57
|
max_num: Optional[int] = None,
|
|
58
58
|
):
|
|
59
|
+
if isinstance(save_path, str):
|
|
60
|
+
save_path = Path(save_path)
|
|
59
61
|
self.save_path = save_path
|
|
60
62
|
self.first_token_only = first_token_only
|
|
61
63
|
self.max_num = max_num
|
|
@@ -122,9 +124,9 @@ class CLIPVisionModelTaskPool(
|
|
|
122
124
|
self,
|
|
123
125
|
test_datasets: Union[DictConfig, Dict[str, Dataset]],
|
|
124
126
|
*,
|
|
125
|
-
processor: Union[DictConfig, CLIPProcessor],
|
|
126
|
-
|
|
127
|
-
|
|
127
|
+
processor: Union[str, DictConfig, CLIPProcessor],
|
|
128
|
+
clip_model: Union[str, DictConfig, CLIPModel],
|
|
129
|
+
data_processor: Union[DictConfig, CLIPProcessor] = None,
|
|
128
130
|
dataloader_kwargs: DictConfig = None,
|
|
129
131
|
layer_wise_feature_save_path: Optional[str] = None,
|
|
130
132
|
layer_wise_feature_first_token_only: bool = True,
|
|
@@ -159,21 +161,35 @@ class CLIPVisionModelTaskPool(
|
|
|
159
161
|
Set up the processor, data processor, CLIP model, test datasets, and data loaders.
|
|
160
162
|
"""
|
|
161
163
|
# setup processor and clip model
|
|
162
|
-
self.
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
instantiate(self.
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
self.
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
164
|
+
if isinstance(self._processor, str):
|
|
165
|
+
self.processor = CLIPProcessor.from_pretrained(self._processor)
|
|
166
|
+
elif (
|
|
167
|
+
isinstance(self._processor, (dict, DictConfig))
|
|
168
|
+
and "_target_" in self._processor
|
|
169
|
+
):
|
|
170
|
+
self.processor = instantiate(self._processor)
|
|
171
|
+
else:
|
|
172
|
+
self.processor = self._processor
|
|
173
|
+
|
|
174
|
+
if self._data_processor is None:
|
|
175
|
+
self.data_processor = self.processor
|
|
176
|
+
else:
|
|
177
|
+
self.data_processor = (
|
|
178
|
+
instantiate(self._data_processor)
|
|
179
|
+
if isinstance(self._data_processor, DictConfig)
|
|
180
|
+
else self._data_processor
|
|
181
|
+
)
|
|
182
|
+
|
|
183
|
+
if isinstance(self._clip_model, str):
|
|
184
|
+
self.clip_model = CLIPModel.from_pretrained(self._clip_model)
|
|
185
|
+
elif (
|
|
186
|
+
isinstance(self._clip_model, (dict, DictConfig))
|
|
187
|
+
and "_target_" in self._clip_model
|
|
188
|
+
):
|
|
189
|
+
self.clip_model = instantiate(self._clip_model)
|
|
190
|
+
else:
|
|
191
|
+
self.clip_model = self._clip_model
|
|
192
|
+
|
|
177
193
|
self.clip_model = self.fabric.to_device(self.clip_model)
|
|
178
194
|
self.clip_model.requires_grad_(False)
|
|
179
195
|
self.clip_model.eval()
|
fusion_bench/taskpool/dummy.py
CHANGED
|
@@ -4,13 +4,13 @@ This is the dummy task pool that is used for debugging purposes.
|
|
|
4
4
|
|
|
5
5
|
from typing import Optional
|
|
6
6
|
|
|
7
|
+
from lightning.pytorch.utilities import rank_zero_only
|
|
7
8
|
from torch import nn
|
|
8
9
|
|
|
9
10
|
from fusion_bench.models.separate_io import separate_save
|
|
10
11
|
from fusion_bench.taskpool.base_pool import BaseTaskPool
|
|
11
12
|
from fusion_bench.utils import timeit_context
|
|
12
13
|
from fusion_bench.utils.parameters import count_parameters, print_parameters
|
|
13
|
-
from lightning.pytorch.utilities import rank_zero_only
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def get_model_summary(model: nn.Module) -> dict:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
|
-
from typing import List, Literal, Optional, Union
|
|
3
|
+
from typing import TYPE_CHECKING, List, Literal, Optional, Union
|
|
4
4
|
|
|
5
5
|
import lightning.fabric
|
|
6
6
|
import lm_eval
|
|
@@ -12,7 +12,6 @@ from fusion_bench import BaseTaskPool
|
|
|
12
12
|
from fusion_bench.mixins import LightningFabricMixin
|
|
13
13
|
from fusion_bench.utils.strenum import _version
|
|
14
14
|
|
|
15
|
-
|
|
16
15
|
log = logging.getLogger(__name__)
|
|
17
16
|
|
|
18
17
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import importlib
|
|
2
2
|
import warnings
|
|
3
|
-
from typing import Any, Callable, Dict, List
|
|
3
|
+
from typing import Any, Callable, Dict, List, Tuple
|
|
4
4
|
|
|
5
5
|
from datasets import load_dataset
|
|
6
6
|
|
|
@@ -79,7 +79,9 @@ class CLIPTemplateFactory:
|
|
|
79
79
|
}
|
|
80
80
|
|
|
81
81
|
@staticmethod
|
|
82
|
-
def get_classnames_and_templates(
|
|
82
|
+
def get_classnames_and_templates(
|
|
83
|
+
dataset_name: str,
|
|
84
|
+
) -> Tuple[List[str], List[Callable]]:
|
|
83
85
|
"""
|
|
84
86
|
Retrieves class names and templates for the specified dataset.
|
|
85
87
|
|
|
@@ -169,7 +171,7 @@ class CLIPTemplateFactory:
|
|
|
169
171
|
CLIPTemplateFactory._dataset_mapping[dataset_name] = dataset_info
|
|
170
172
|
|
|
171
173
|
@staticmethod
|
|
172
|
-
def get_available_datasets():
|
|
174
|
+
def get_available_datasets() -> List[str]:
|
|
173
175
|
"""
|
|
174
176
|
Get a list of all available dataset names.
|
|
175
177
|
|
|
@@ -179,5 +181,5 @@ class CLIPTemplateFactory:
|
|
|
179
181
|
return list(CLIPTemplateFactory._dataset_mapping.keys())
|
|
180
182
|
|
|
181
183
|
|
|
182
|
-
def get_classnames_and_templates(dataset_name: str) -> Tuple[List[str], List[Callable]]:
    """
    Module-level shortcut for ``CLIPTemplateFactory.get_classnames_and_templates``.

    Args:
        dataset_name: Name of the dataset whose class names and templates are requested.

    Returns:
        The ``(classnames, templates)`` pair produced by the factory.
    """
    factory = CLIPTemplateFactory
    return factory.get_classnames_and_templates(dataset_name)
|
fusion_bench/utils/__init__.py
CHANGED
|
@@ -7,7 +7,12 @@ from .cache_utils import *
|
|
|
7
7
|
from .devices import *
|
|
8
8
|
from .dtype import parse_dtype
|
|
9
9
|
from .fabric import seed_everything_by_time
|
|
10
|
-
from .instantiate_utils import
|
|
10
|
+
from .instantiate_utils import (
|
|
11
|
+
instantiate,
|
|
12
|
+
is_instantiable,
|
|
13
|
+
set_print_function_call,
|
|
14
|
+
set_print_function_call_permeanent,
|
|
15
|
+
)
|
|
11
16
|
from .json import load_from_json, save_to_json
|
|
12
17
|
from .lazy_state_dict import LazyStateDict
|
|
13
18
|
from .misc import *
|
fusion_bench/utils/devices.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import gc
|
|
2
|
+
import logging
|
|
2
3
|
import os
|
|
3
4
|
from typing import List, Optional, Union
|
|
4
5
|
|
|
@@ -12,7 +13,7 @@ from transformers.utils import (
|
|
|
12
13
|
)
|
|
13
14
|
|
|
14
15
|
__all__ = [
|
|
15
|
-
"
|
|
16
|
+
"clear_cuda_cache",
|
|
16
17
|
"to_device",
|
|
17
18
|
"num_devices",
|
|
18
19
|
"get_device",
|
|
@@ -21,10 +22,19 @@ __all__ = [
|
|
|
21
22
|
"get_device_capabilities",
|
|
22
23
|
]
|
|
23
24
|
|
|
25
|
+
log = logging.getLogger(__name__)
|
|
24
26
|
|
|
25
|
-
|
|
27
|
+
|
|
28
|
+
def clear_cuda_cache():
    """
    Run Python garbage collection and, if CUDA is available, release the
    unused cached GPU memory held by PyTorch (``torch.cuda.empty_cache``).

    When CUDA is not available, only the garbage-collection pass runs and a
    warning is logged.
    """
    gc.collect()
    if not torch.cuda.is_available():
        log.warning("CUDA is not available. No cache to clear.")
        return
    torch.cuda.empty_cache()
|
|
28
38
|
|
|
29
39
|
|
|
30
40
|
def to_device(obj, device: Optional[torch.device], **kwargs):
|
|
@@ -75,7 +85,7 @@ def num_devices(devices: Union[int, List[int], str]) -> int:
|
|
|
75
85
|
Return the number of devices.
|
|
76
86
|
|
|
77
87
|
Args:
|
|
78
|
-
devices: `devices` can be a single int to specify the number of devices, or a list of device ids, e.g. [0, 1, 2, 3]
|
|
88
|
+
devices: `devices` can be a single int to specify the number of devices, or a list of device ids, e.g. [0, 1, 2, 3], or a str of device ids, e.g. "0,1,2,3" and "[0, 1, 2]".
|
|
79
89
|
|
|
80
90
|
Returns:
|
|
81
91
|
The number of devices.
|
|
@@ -28,7 +28,7 @@ PRINT_FUNCTION_CALL_FUNC = print
|
|
|
28
28
|
Function to be used for printing function calls.
|
|
29
29
|
"""
|
|
30
30
|
|
|
31
|
-
CATCH_EXCEPTION =
|
|
31
|
+
CATCH_EXCEPTION = False
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
@contextmanager
|
|
@@ -41,10 +41,12 @@ def set_print_function_call(value: bool):
|
|
|
41
41
|
finally:
|
|
42
42
|
PRINT_FUNCTION_CALL = old_value
|
|
43
43
|
|
|
44
|
+
|
|
44
45
|
def set_print_function_call_permeanent(value: bool):
    """
    Set the module-level ``PRINT_FUNCTION_CALL`` flag permanently.

    Unlike the ``set_print_function_call`` context manager, this function does
    not restore the previous value afterwards.

    Args:
        value: New value for ``PRINT_FUNCTION_CALL``.
    """
    global PRINT_FUNCTION_CALL
    PRINT_FUNCTION_CALL = value
|
|
47
48
|
|
|
49
|
+
|
|
48
50
|
def is_instantiable(config: Union[DictConfig, Any]) -> bool:
|
|
49
51
|
if OmegaConf.is_dict(config):
|
|
50
52
|
return "_target_" in config
|
fusion_bench/utils/misc.py
CHANGED
|
@@ -1,6 +1,13 @@
|
|
|
1
|
-
from
|
|
1
|
+
from difflib import get_close_matches
|
|
2
|
+
from typing import Any, Iterable, List, Optional
|
|
2
3
|
|
|
3
|
-
__all__ = [
|
|
4
|
+
__all__ = [
|
|
5
|
+
"first",
|
|
6
|
+
"has_length",
|
|
7
|
+
"join_list",
|
|
8
|
+
"attr_equal",
|
|
9
|
+
"validate_and_suggest_corrections",
|
|
10
|
+
]
|
|
4
11
|
|
|
5
12
|
|
|
6
13
|
def first(iterable: Iterable):
|
|
@@ -41,3 +48,42 @@ def attr_equal(obj, attr: str, value):
|
|
|
41
48
|
if not hasattr(obj, attr):
|
|
42
49
|
return False
|
|
43
50
|
return getattr(obj, attr) == value
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def validate_and_suggest_corrections(
    obj: Any, values: Iterable[Any], *, max_suggestions: int = 3, cutoff: float = 0.6
) -> Any:
    """
    Check that *obj* is one of the allowed *values*.

    If it is, *obj* is returned unchanged. Otherwise a ``ValueError`` is raised.
    The error message lists every allowed value and, when some of them look
    similar to *obj* (compared as strings with ``difflib.get_close_matches``),
    names them as likely intended spellings.

    Args:
        obj: The value to validate.
        values: The allowed values. The iterable is consumed exactly once, so a
            generator is acceptable.
        max_suggestions: Maximum number of typo hints in the message (default 3).
        cutoff: Similarity threshold in [0.0, 1.0] a value must reach to be
            suggested (default 0.6).

    Returns:
        *obj*, when it is one of *values*.

    Raises:
        ValueError: If *obj* is not one of *values*.
    """
    # Materialize once: reused for the membership test, the message and the hints.
    allowed = [*values]
    if obj in allowed:
        return obj

    suggestions = get_close_matches(
        str(obj), [str(v) for v in allowed], n=max_suggestions, cutoff=cutoff
    )
    message = f"Invalid value {obj!r}. Allowed values: {allowed}"
    if suggestions:
        hint = ", ".join(repr(s) for s in suggestions)
        message += f". Did you mean {hint}?"
    raise ValueError(message)
|
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from typing import Literal, Optional
|
|
3
|
+
|
|
4
|
+
from datasets import load_dataset as datasets_load_dataset
|
|
5
|
+
|
|
6
|
+
from fusion_bench.utils import validate_and_suggest_corrections
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
from modelscope import dataset_file_download as modelscope_dataset_file_download
|
|
10
|
+
from modelscope import model_file_download as modelscope_model_file_download
|
|
11
|
+
from modelscope import snapshot_download as modelscope_snapshot_download
|
|
12
|
+
|
|
13
|
+
except ImportError:
|
|
14
|
+
|
|
15
|
+
def _raise_modelscope_not_installed_error(*args, **kwargs):
|
|
16
|
+
raise ImportError(
|
|
17
|
+
"ModelScope is not installed. Please install it using `pip install modelscope` to use ModelScope models."
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
modelscope_snapshot_download = _raise_modelscope_not_installed_error
|
|
21
|
+
modelscope_model_file_download = _raise_modelscope_not_installed_error
|
|
22
|
+
modelscope_dataset_file_download = _raise_modelscope_not_installed_error
|
|
23
|
+
|
|
24
|
+
try:
|
|
25
|
+
from huggingface_hub import hf_hub_download
|
|
26
|
+
from huggingface_hub import snapshot_download as huggingface_snapshot_download
|
|
27
|
+
except ImportError:
|
|
28
|
+
|
|
29
|
+
def _raise_hugggingface_not_installed_error(*args, **kwargs):
|
|
30
|
+
raise ImportError(
|
|
31
|
+
"Hugging Face Hub is not installed. Please install it using `pip install huggingface_hub` to use Hugging Face models."
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
huggingface_snapshot_download = _raise_hugggingface_not_installed_error
|
|
35
|
+
hf_hub_download = _raise_hugggingface_not_installed_error
|
|
36
|
+
|
|
37
|
+
__all__ = [
    "load_dataset",
    "resolve_repo_path",
    "resolve_file_path",
]

# Accepted values for the ``platform`` argument; "hf" and "huggingface" are aliases.
AVAILABLE_PLATFORMS = ["hf", "huggingface", "modelscope"]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _raise_unknown_platform_error(platform: str):
|
|
46
|
+
raise ValueError(
|
|
47
|
+
f"Unsupported platform: {platform}. Supported platforms are 'hf', 'huggingface', and 'modelscope'."
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def load_dataset(
    name: str,
    split: str = "train",
    platform: Literal["hf", "huggingface", "modelscope"] = "hf",
):
    """
    Load a dataset from Hugging Face or ModelScope.

    Args:
        name (str): The name of the dataset. For ModelScope this repository is
            first downloaded to the local cache and then loaded from there.
        split (str): The split of the dataset to load (default is "train").
        platform (Literal["hf", "huggingface", "modelscope"]): The platform to
            load the dataset from. "hf" and "huggingface" are synonyms.
            Defaults to "hf".

    Returns:
        Dataset: The loaded dataset.

    Raises:
        ValueError: If ``platform`` is not one of the supported platforms.
        ImportError: If ``platform="modelscope"`` and ModelScope is not installed.
    """
    validate_and_suggest_corrections(platform, AVAILABLE_PLATFORMS)
    if platform in ("hf", "huggingface"):
        return datasets_load_dataset(name, split=split)
    if platform == "modelscope":
        # Fetch the dataset repository into the local cache, then let
        # `datasets` load it from that directory.
        dataset_dir = modelscope_snapshot_download(name, repo_type="dataset")
        return datasets_load_dataset(dataset_dir, split=split)
    # Defensive: unreachable after validation, kept for safety.
    _raise_unknown_platform_error(platform)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def resolve_repo_path(
    repo_id: str,
    repo_type: Optional[str] = "model",
    platform: Literal["hf", "huggingface", "modelscope"] = "hf",
    **kwargs,
):
    """
    Resolve and download a repository from various platforms to a local path.

    This function handles multiple repository sources including local paths, Hugging Face,
    and ModelScope. It automatically downloads remote repositories to local cache and
    returns the local path for further use.

    Args:
        repo_id (str): Repository identifier. Can be:
            - Local file/directory path (returned as-is if exists)
            - Hugging Face model/dataset ID (e.g., "bert-base-uncased")
            - ModelScope model/dataset ID
            - URL-prefixed ID (e.g., "hf://model-name", "modelscope://model-name").
              The prefix will override the platform argument.
        repo_type (str, optional): Type of repository to download. Defaults to "model".
            Common values include "model" and "dataset".
        platform (Literal["hf", "huggingface", "modelscope"], optional):
            Platform to download from. Defaults to "hf". Options:
            - "hf" or "huggingface": Hugging Face Hub
            - "modelscope": ModelScope platform
        **kwargs: Additional arguments passed to the underlying download functions.

    Returns:
        str: Local path to the repository (either existing local path or downloaded cache path).

    Raises:
        FileNotFoundError: If the download from the selected platform fails. The
            underlying error is chained as ``__cause__``.
        ValueError: If an unsupported platform is specified.
        ImportError: If required dependencies for the specified platform are not installed.

    Examples:
        >>> # Local path (returned as-is)
        >>> resolve_repo_path("/path/to/local/model")
        "/path/to/local/model"

        >>> # Hugging Face model
        >>> resolve_repo_path("bert-base-uncased")
        "/home/user/.cache/huggingface/hub/models--bert-base-uncased/..."

        >>> # ModelScope model with explicit platform
        >>> resolve_repo_path("damo/nlp_bert_backbone_base_std", platform="modelscope")
        "/home/user/.cache/modelscope/hub/damo/nlp_bert_backbone_base_std/..."

        >>> # URL-prefixed repository ID
        >>> resolve_repo_path("hf://microsoft/DialoGPT-medium")
        "/home/user/.cache/huggingface/hub/models--microsoft--DialoGPT-medium/..."
    """
    # A URL-style prefix overrides the ``platform`` argument. Only the leading
    # prefix is stripped; the rest of the ID is left untouched.
    for prefix, prefix_platform in (
        ("hf://", "hf"),
        ("huggingface://", "hf"),
        ("modelscope://", "modelscope"),
    ):
        if repo_id.startswith(prefix):
            repo_id = repo_id[len(prefix) :]
            platform = prefix_platform
            break

    # If it's a local file or directory, return as is
    if os.path.exists(repo_id):
        return repo_id

    # Validate outside the try-block so an unsupported platform surfaces as the
    # documented ValueError instead of being rewrapped as FileNotFoundError.
    validate_and_suggest_corrections(platform, AVAILABLE_PLATFORMS)
    if platform in ("hf", "huggingface"):
        download_snapshot = huggingface_snapshot_download
    elif platform == "modelscope":
        download_snapshot = modelscope_snapshot_download
    else:
        _raise_unknown_platform_error(platform)

    try:
        # Downloads the repository to the platform's cache and returns the local path.
        return download_snapshot(repo_id=repo_id, repo_type=repo_type, **kwargs)
    except ImportError:
        # Missing optional dependency: propagate unchanged, as documented.
        raise
    except Exception as e:
        raise FileNotFoundError(
            f"Could not resolve checkpoint: {repo_id}. Error: {e}"
        ) from e
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def resolve_file_path(
    repo_id: str,
    filename: str,
    repo_type: Literal["model", "dataset"] = "model",
    platform: Literal["hf", "huggingface", "modelscope"] = "hf",
    **kwargs,
) -> str:
    """
    Resolve and download a specific file from a repository across multiple platforms.

    The file can come from a local directory, the Hugging Face Hub, or ModelScope.
    A leading URL-style scheme on ``repo_id`` selects the platform explicitly;
    otherwise the ``platform`` argument decides which download helper is used.

    Args:
        repo_id (str): Repository identifier. Can be:
            - Local path. If it exists, ``filename`` is joined onto it and the
              result is returned without downloading anything.
            - Hugging Face model/dataset ID (e.g., "bert-base-uncased")
            - ModelScope model/dataset ID
            - URL-prefixed ID (e.g., "hf://model-name", "huggingface://model-name",
              "modelscope://model-name"). Only the leading prefix is stripped,
              and it overrides the ``platform`` argument.
        filename (str): The specific file to resolve within the repository.
        repo_type (Literal["model", "dataset"], optional): Type of repository.
            Defaults to "model". Forwarded to ``hf_hub_download`` for Hugging Face,
            and used to pick the model vs. dataset download helper for ModelScope.
        platform (Literal["hf", "huggingface", "modelscope"], optional):
            Platform to download from. Defaults to "hf". Options:
            - "hf" or "huggingface": Hugging Face Hub
            - "modelscope": ModelScope platform
        **kwargs: Additional arguments passed unchanged to the underlying download
            function (e.g., cache_dir, force_download).

    Returns:
        str: Local path to the file. For an existing local ``repo_id`` this is
        simply ``os.path.join(repo_id, filename)``. The existence of the file
        itself is NOT checked in that case.

    Raises:
        ValueError: If an unsupported repo_type is specified for the ModelScope platform.
        Exception: Errors raised by the platform download helpers (for example when
            the repository or file cannot be found, or an optional dependency is
            missing) propagate unchanged. An unknown ``platform`` is handed to
            ``_raise_unknown_platform_error``.

    Examples:
        >>> # Download config.json from a Hugging Face model
        >>> resolve_file_path("bert-base-uncased", "config.json")
        "/home/user/.cache/huggingface/hub/models--bert-base-uncased/.../config.json"

        >>> # Download from ModelScope
        >>> resolve_file_path(
        ...     "damo/nlp_bert_backbone_base_std",
        ...     "pytorch_model.bin",
        ...     platform="modelscope"
        ... )
        "/home/user/.cache/modelscope/hub/.../pytorch_model.bin"

        >>> # Local file path
        >>> resolve_file_path("/path/to/local/model", "config.json")
        "/path/to/local/model/config.json"

        >>> # URL-prefixed repository
        >>> resolve_file_path("hf://microsoft/DialoGPT-medium", "config.json")
        "/home/user/.cache/huggingface/hub/.../config.json"

        >>> # Download dataset file from ModelScope
        >>> resolve_file_path(
        ...     "DAMO_NLP/jd",
        ...     "train.json",
        ...     repo_type="dataset",
        ...     platform="modelscope"
        ... )
        "/home/user/.cache/modelscope/datasets/.../train.json"
    """
    # A scheme prefix overrides the `platform` argument. Only the LEADING prefix
    # is removed (str.replace would also delete any later occurrence of the same
    # text inside the id or path).
    for prefix, prefix_platform in (
        ("hf://", "hf"),
        ("huggingface://", "hf"),
        ("modelscope://", "modelscope"),
    ):
        if repo_id.startswith(prefix):
            repo_id = repo_id[len(prefix):]
            platform = prefix_platform
            break

    # An existing local path short-circuits any download. The requested file is
    # expected to live directly inside it (existence of the file is not checked).
    if os.path.exists(repo_id):
        return os.path.join(repo_id, filename)

    if platform in ["hf", "huggingface"]:
        return hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            repo_type=repo_type,
            **kwargs,
        )
    elif platform == "modelscope":
        # ModelScope exposes separate single-file download helpers for models
        # and datasets, with differently named id keyword arguments.
        if repo_type == "model":
            return modelscope_model_file_download(
                model_id=repo_id, file_path=filename, **kwargs
            )
        elif repo_type == "dataset":
            return modelscope_dataset_file_download(
                dataset_id=repo_id, file_path=filename, **kwargs
            )
        else:
            raise ValueError(
                f"Unsupported repo_type: {repo_type}. Supported types are 'model' and 'dataset'."
            )
    else:
        _raise_unknown_platform_error(platform)
|
fusion_bench/utils/parameters.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import copy
|
|
2
2
|
from collections import OrderedDict
|
|
3
|
-
from typing import List, Mapping, Optional, Union
|
|
3
|
+
from typing import Dict, List, Mapping, Optional, Union
|
|
4
4
|
|
|
5
5
|
import torch
|
|
6
6
|
from torch import nn
|
|
@@ -83,7 +83,7 @@ def vector_to_state_dict(
|
|
|
83
83
|
vector: torch.Tensor,
|
|
84
84
|
state_dict: Union[StateDictType, nn.Module],
|
|
85
85
|
remove_keys: Optional[List[str]] = None,
|
|
86
|
-
):
|
|
86
|
+
) -> Dict[str, torch.Tensor]:
|
|
87
87
|
"""
|
|
88
88
|
Convert a vector to a state dictionary.
|
|
89
89
|
|
fusion_bench/utils/rich_utils.py
CHANGED