fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +22 -2
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +6 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +24 -5
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +5 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +17 -13
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +12 -16
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +15 -45
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +275 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +7 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +160 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +46 -61
- fusion_bench/scripts/cli.py +38 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +7 -1
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +38 -3
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/scripts/cli.py
CHANGED
````diff
@@ -12,6 +12,7 @@ import os
 import hydra
 from omegaconf import DictConfig, OmegaConf
 
+from fusion_bench.constants import PROJECT_ROOT_PATH
 from fusion_bench.programs import BaseHydraProgram
 from fusion_bench.utils import instantiate
 
@@ -20,11 +21,10 @@ log = logging.getLogger(__name__)
 
 def _get_default_config_path():
     for config_dir in ["fusion_bench_config", "config"]:
-        …
-        …
-        …
-        …
-            return os.path.abspath(config_path)
+        for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+            config_path = os.path.join(config_path_root, config_dir)
+            if os.path.exists(config_path) and os.path.isdir(config_path):
+                return os.path.abspath(config_path)
     return None
 
 
@@ -34,6 +34,39 @@ def _get_default_config_path():
     version_base=None,
 )
 def main(cfg: DictConfig) -> None:
+    """
+    Main entry point for the FusionBench command-line interface.
+
+    This function serves as the primary entry point for the `fusion_bench` CLI command.
+    It is decorated with Hydra's main decorator to handle configuration management,
+    command-line argument parsing, and configuration file loading.
+
+    The function performs the following operations:
+    1. Resolves any interpolations in the configuration using OmegaConf
+    2. Instantiates the appropriate program class based on the configuration
+    3. Executes the program's run method to perform the fusion task
+
+    Args:
+        cfg (DictConfig): The Hydra configuration object containing all settings
+            for the fusion task. This includes method configuration, model pool
+            configuration, task pool configuration, and other runtime parameters.
+            The configuration is automatically loaded by Hydra from the specified
+            config files and command-line overrides.
+
+    Returns:
+        None: This function doesn't return a value but executes the fusion
+            program which may save results, log outputs, or perform other
+            side effects as configured.
+
+    Example:
+        This function is typically called automatically when running:
+
+        ```bash
+        fusion_bench method=... modelpool=... taskpool=...
+        ```
+
+        The Hydra decorator handles parsing these command-line arguments and
+        loading the corresponding configuration files to populate the cfg parameter.
+    """
     OmegaConf.resolve(cfg)
     program: BaseHydraProgram = instantiate(cfg)
     program.run()
````
fusion_bench/taskpool/base_pool.py
CHANGED
```diff
@@ -1,14 +1,15 @@
 from abc import abstractmethod
+from typing import Any, Dict
 
-from fusion_bench.mixins import …
+from fusion_bench.mixins import BaseYAMLSerializable
 
 
-class BaseTaskPool(…
+class BaseTaskPool(BaseYAMLSerializable):
     _program = None
     _config_key = "taskpool"
 
     @abstractmethod
-    def evaluate(self, model, *args, **kwargs):
+    def evaluate(self, model: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """
         Evaluate the model on all tasks in the task pool, and return a report.
 
```
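For orientation, a minimal sketch of a subclass written against the new typed `evaluate` signature; the class name and the returned report are made up for illustration, and any constructor arguments required by `BaseYAMLSerializable` are not shown here:

```python
from typing import Any, Dict

from fusion_bench.taskpool.base_pool import BaseTaskPool


class ConstantTaskPool(BaseTaskPool):
    """Hypothetical task pool that reports a fixed score for any model."""

    def evaluate(self, model: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # In 0.2.22 evaluate() is annotated as returning a Dict[str, Any] report.
        return {"toy_task": {"accuracy": 0.0}}
```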
fusion_bench/taskpool/clip_vision/taskpool.py
CHANGED
```diff
@@ -27,8 +27,9 @@ from tqdm.autonotebook import tqdm
 from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from transformers.models.clip.modeling_clip import CLIPVisionTransformer
 
+from fusion_bench import RuntimeConstants
 from fusion_bench.dataset import CLIPDataset
-from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.mixins import HydraConfigMixin, LightningFabricMixin
 from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.taskpool import BaseTaskPool
 from fusion_bench.tasks.clip_classification import get_classnames_and_templates
@@ -56,6 +57,8 @@ class LayerWiseFeatureSaver:
         first_token_only: bool = True,
         max_num: Optional[int] = None,
     ):
+        if isinstance(save_path, str):
+            save_path = Path(save_path)
         self.save_path = save_path
         self.first_token_only = first_token_only
         self.max_num = max_num
@@ -84,8 +87,9 @@ class LayerWiseFeatureSaver:
 
 
 class CLIPVisionModelTaskPool(
-    …
+    HydraConfigMixin,
     LightningFabricMixin,
+    BaseTaskPool,
 ):
     """
     This class is used to define the image classification task for CLIP models.
@@ -122,14 +126,14 @@ class CLIPVisionModelTaskPool(
         self,
         test_datasets: Union[DictConfig, Dict[str, Dataset]],
         *,
-        processor: Union[DictConfig, CLIPProcessor],
-        …
-        …
+        processor: Union[str, DictConfig, CLIPProcessor],
+        clip_model: Union[str, DictConfig, CLIPModel],
+        data_processor: Union[DictConfig, CLIPProcessor] = None,
         dataloader_kwargs: DictConfig = None,
         layer_wise_feature_save_path: Optional[str] = None,
         layer_wise_feature_first_token_only: bool = True,
         layer_wise_feature_max_num: Optional[int] = None,
-        fast_dev_run: bool = …
+        fast_dev_run: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -151,7 +155,10 @@ class CLIPVisionModelTaskPool(
         self.layer_wise_feature_first_token_only = layer_wise_feature_first_token_only
         self.layer_wise_feature_max_num = layer_wise_feature_max_num
 
-        self.fast_dev_run …
+        if self.fast_dev_run is None:
+            self.fast_dev_run = RuntimeConstants().debug
+        else:
+            self.fast_dev_run = fast_dev_run
         super().__init__(**kwargs)
 
     def setup(self):
@@ -159,21 +166,35 @@ class CLIPVisionModelTaskPool(
         Set up the processor, data processor, CLIP model, test datasets, and data loaders.
         """
         # setup processor and clip model
-        self.…
-        …
-        …
-        …
-        …
-        …
-            instantiate(self.…
-        …
-        …
-        …
-        self.…
-        …
-        …
-        …
-        …
+        if isinstance(self._processor, str):
+            self.processor = CLIPProcessor.from_pretrained(self._processor)
+        elif (
+            isinstance(self._processor, (dict, DictConfig))
+            and "_target_" in self._processor
+        ):
+            self.processor = instantiate(self._processor)
+        else:
+            self.processor = self._processor
+
+        if self._data_processor is None:
+            self.data_processor = self.processor
+        else:
+            self.data_processor = (
+                instantiate(self._data_processor)
+                if isinstance(self._data_processor, DictConfig)
+                else self._data_processor
+            )
+
+        if isinstance(self._clip_model, str):
+            self.clip_model = CLIPModel.from_pretrained(self._clip_model)
+        elif (
+            isinstance(self._clip_model, (dict, DictConfig))
+            and "_target_" in self._clip_model
+        ):
+            self.clip_model = instantiate(self._clip_model)
+        else:
+            self.clip_model = self._clip_model
+
         self.clip_model = self.fabric.to_device(self.clip_model)
         self.clip_model.requires_grad_(False)
         self.clip_model.eval()
```
fusion_bench/taskpool/dummy.py
CHANGED
```diff
@@ -4,13 +4,13 @@ This is the dummy task pool that is used for debugging purposes.
 
 from typing import Optional
 
+from lightning.pytorch.utilities import rank_zero_only
 from torch import nn
 
 from fusion_bench.models.separate_io import separate_save
 from fusion_bench.taskpool.base_pool import BaseTaskPool
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.parameters import count_parameters, print_parameters
-from lightning.pytorch.utilities import rank_zero_only
 
 
 def get_model_summary(model: nn.Module) -> dict:
```
fusion_bench/taskpool/lm_eval_harness/taskpool.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import List, Literal, Optional, Union
+from typing import TYPE_CHECKING, List, Literal, Optional, Union
 
 import lightning.fabric
 import lm_eval
@@ -12,7 +12,6 @@ from fusion_bench import BaseTaskPool
 from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.utils.strenum import _version
 
-
 log = logging.getLogger(__name__)
 
 
```
fusion_bench/tasks/clip_classification/__init__.py
CHANGED
```diff
@@ -1,6 +1,6 @@
 import importlib
 import warnings
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable, Dict, List, Tuple
 
 from datasets import load_dataset
 
@@ -79,7 +79,9 @@ class CLIPTemplateFactory:
     }
 
     @staticmethod
-    def get_classnames_and_templates(…
+    def get_classnames_and_templates(
+        dataset_name: str,
+    ) -> Tuple[List[str], List[Callable]]:
         """
         Retrieves class names and templates for the specified dataset.
 
@@ -169,7 +171,7 @@ class CLIPTemplateFactory:
         CLIPTemplateFactory._dataset_mapping[dataset_name] = dataset_info
 
     @staticmethod
-    def get_available_datasets():
+    def get_available_datasets() -> List[str]:
         """
         Get a list of all available dataset names.
 
@@ -179,5 +181,5 @@ class CLIPTemplateFactory:
         return list(CLIPTemplateFactory._dataset_mapping.keys())
 
 
-def get_classnames_and_templates(dataset_name: str):
+def get_classnames_and_templates(dataset_name: str) -> Tuple[List[str], List[Callable]]:
    return CLIPTemplateFactory.get_classnames_and_templates(dataset_name)
```
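A usage sketch of the newly annotated return type. It assumes `"mnist"` is one of the registered dataset names and that each template is a prompt-forming callable that takes a class name, which is how CLIP-style templates are usually defined; neither assumption is confirmed by this diff:

```python
from fusion_bench.tasks.clip_classification import get_classnames_and_templates

# The annotated return type is Tuple[List[str], List[Callable]]:
# class names plus prompt-template callables.
classnames, templates = get_classnames_and_templates("mnist")  # "mnist" assumed registered
print(classnames[:3])
print(templates[0](classnames[0]))  # assumes templates map a class name to a prompt string
```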
fusion_bench/utils/__init__.py
CHANGED
```diff
@@ -7,10 +7,16 @@ from .cache_utils import *
 from .devices import *
 from .dtype import parse_dtype
 from .fabric import seed_everything_by_time
-from .instantiate_utils import …
+from .instantiate_utils import (
+    instantiate,
+    is_instantiable,
+    set_print_function_call,
+    set_print_function_call_permeanent,
+)
 from .json import load_from_json, save_to_json
 from .lazy_state_dict import LazyStateDict
 from .misc import *
 from .packages import import_object
 from .parameters import *
+from .pylogger import get_rankzero_logger
 from .timer import timeit_context
```
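A small sketch of the widened re-export surface; it assumes `instantiate` follows Hydra's `_target_` convention, as the configuration files shipped in this release do:

```python
from omegaconf import DictConfig

from fusion_bench.utils import get_rankzero_logger, instantiate, is_instantiable

log = get_rankzero_logger(__name__)

cfg = DictConfig({"_target_": "collections.OrderedDict"})  # any _target_-style node
if is_instantiable(cfg):          # True, because the DictConfig carries a _target_
    obj = instantiate(cfg)        # assumed to resolve the _target_ like Hydra does
    log.info("instantiated %s", type(obj).__name__)
```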
|
@@ -1,15 +1,30 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
import os
|
|
3
3
|
import pickle
|
|
4
|
+
import warnings
|
|
4
5
|
from functools import wraps
|
|
5
6
|
from pathlib import Path
|
|
6
7
|
from typing import Any, Callable, Union
|
|
7
8
|
|
|
8
|
-
|
|
9
|
+
from joblib import Memory
|
|
10
|
+
|
|
11
|
+
__all__ = ["cache_to_disk", "cache_with_joblib", "set_default_cache_dir"]
|
|
9
12
|
|
|
10
13
|
|
|
11
14
|
log = logging.getLogger(__name__)
|
|
12
15
|
|
|
16
|
+
DEFAULT_CACHE_DIR = Path.cwd() / "outputs" / "cache"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def set_default_cache_dir(path: str | Path):
|
|
20
|
+
global DEFAULT_CACHE_DIR
|
|
21
|
+
if path is None:
|
|
22
|
+
return
|
|
23
|
+
|
|
24
|
+
if isinstance(path, str):
|
|
25
|
+
path = Path(path)
|
|
26
|
+
DEFAULT_CACHE_DIR = path
|
|
27
|
+
|
|
13
28
|
|
|
14
29
|
def cache_to_disk(file_path: Union[str, Path]) -> Callable:
|
|
15
30
|
"""
|
|
@@ -17,6 +32,11 @@ def cache_to_disk(file_path: Union[str, Path]) -> Callable:
|
|
|
17
32
|
the result is loaded from the file. Otherwise, the function is executed and
|
|
18
33
|
the result is saved to the file.
|
|
19
34
|
|
|
35
|
+
!!! warning "deprecated"
|
|
36
|
+
This function is deprecated. Use `cache_with_joblib` instead for better
|
|
37
|
+
caching capabilities including automatic cache invalidation, better object
|
|
38
|
+
handling, and memory efficiency.
|
|
39
|
+
|
|
20
40
|
## Example usage
|
|
21
41
|
|
|
22
42
|
```python
|
|
@@ -32,6 +52,13 @@ def cache_to_disk(file_path: Union[str, Path]) -> Callable:
|
|
|
32
52
|
Returns:
|
|
33
53
|
Callable: The decorated function.
|
|
34
54
|
"""
|
|
55
|
+
warnings.warn(
|
|
56
|
+
"cache_to_disk is deprecated. Use cache_with_joblib instead for better "
|
|
57
|
+
"caching capabilities including automatic cache invalidation, better object "
|
|
58
|
+
"handling, and memory efficiency.",
|
|
59
|
+
DeprecationWarning,
|
|
60
|
+
stacklevel=2,
|
|
61
|
+
)
|
|
35
62
|
if isinstance(file_path, str):
|
|
36
63
|
file_path = Path(file_path)
|
|
37
64
|
assert isinstance(file_path, Path)
|
|
@@ -56,3 +83,76 @@ def cache_to_disk(file_path: Union[str, Path]) -> Callable:
|
|
|
56
83
|
return wrapper
|
|
57
84
|
|
|
58
85
|
return decorator
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def cache_with_joblib(
|
|
89
|
+
cache_dir: Union[str, Path] = None,
|
|
90
|
+
verbose: int = 0,
|
|
91
|
+
) -> Callable:
|
|
92
|
+
"""
|
|
93
|
+
A decorator to cache the result of a function using joblib.Memory. This provides
|
|
94
|
+
more advanced caching capabilities compared to cache_to_disk, including:
|
|
95
|
+
- Automatic cache invalidation when function arguments change
|
|
96
|
+
- Better handling of numpy arrays and other complex objects
|
|
97
|
+
- Memory-efficient storage
|
|
98
|
+
- Optional verbose output for cache hits/misses
|
|
99
|
+
|
|
100
|
+
## Example usage
|
|
101
|
+
|
|
102
|
+
```python
|
|
103
|
+
@cache_with_joblib("./cache", verbose=1)
|
|
104
|
+
def expensive_computation(x: int, y: str) -> Any:
|
|
105
|
+
# Function implementation
|
|
106
|
+
return complex_result
|
|
107
|
+
|
|
108
|
+
# Or with default settings:
|
|
109
|
+
@cache_with_joblib()
|
|
110
|
+
def another_function(x: int) -> int:
|
|
111
|
+
return x * 2
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
Args:
|
|
115
|
+
cache_dir (Union[str, Path]): The directory where cache files should be stored.
|
|
116
|
+
If `None`, a default directory `outputs/cache` will be used.
|
|
117
|
+
verbose (int): Verbosity level for joblib.Memory (0=silent, 1=basic, 2++=verbose).
|
|
118
|
+
|
|
119
|
+
Returns:
|
|
120
|
+
Callable: A decorator function that can be applied to functions.
|
|
121
|
+
"""
|
|
122
|
+
|
|
123
|
+
if cache_dir is None:
|
|
124
|
+
cache_dir = DEFAULT_CACHE_DIR
|
|
125
|
+
|
|
126
|
+
if isinstance(cache_dir, str):
|
|
127
|
+
cache_dir = Path(cache_dir)
|
|
128
|
+
assert isinstance(cache_dir, Path)
|
|
129
|
+
|
|
130
|
+
# Create the cache directory if it doesn't exist
|
|
131
|
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
|
132
|
+
|
|
133
|
+
# Create a Memory object for this function
|
|
134
|
+
memory = Memory(location=cache_dir, verbose=verbose)
|
|
135
|
+
|
|
136
|
+
def decorator(func: Callable) -> Callable:
|
|
137
|
+
nonlocal memory
|
|
138
|
+
|
|
139
|
+
# Create the cached version of the function
|
|
140
|
+
cached_func = memory.cache(func)
|
|
141
|
+
|
|
142
|
+
@wraps(func)
|
|
143
|
+
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
144
|
+
return cached_func(*args, **kwargs)
|
|
145
|
+
|
|
146
|
+
# Expose useful methods from joblib.Memory
|
|
147
|
+
if not (
|
|
148
|
+
hasattr(cached_func, "clear")
|
|
149
|
+
or hasattr(cached_func, "call")
|
|
150
|
+
or hasattr(cached_func, "check_call_in_cache")
|
|
151
|
+
):
|
|
152
|
+
wrapper.clear = cached_func.clear
|
|
153
|
+
wrapper.call = cached_func.call
|
|
154
|
+
wrapper.check_call_in_cache = cached_func.check_call_in_cache
|
|
155
|
+
|
|
156
|
+
return wrapper
|
|
157
|
+
|
|
158
|
+
return decorator
|
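An end-to-end sketch of the new joblib-backed cache as shown above; the cache directory path is made up for the example:

```python
from fusion_bench.utils.cache_utils import cache_with_joblib, set_default_cache_dir

# DEFAULT_CACHE_DIR is read when the decorator is applied, so redirect it first.
set_default_cache_dir("./outputs/demo_cache")  # hypothetical location


@cache_with_joblib()  # no cache_dir given -> falls back to DEFAULT_CACHE_DIR
def slow_square(x: int) -> int:
    print("computing", x)
    return x * x


print(slow_square(4))  # computes and writes the result to the joblib cache
print(slow_square(4))  # served from cache; "computing" is not printed again
```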
fusion_bench/utils/devices.py
CHANGED
```diff
@@ -1,4 +1,5 @@
 import gc
+import logging
 import os
 from typing import List, Optional, Union
 
@@ -12,7 +13,7 @@ from transformers.utils import (
 )
 
 __all__ = [
-    "…
+    "clear_cuda_cache",
     "to_device",
     "num_devices",
     "get_device",
@@ -21,10 +22,19 @@ __all__ = [
     "get_device_capabilities",
 ]
 
+log = logging.getLogger(__name__)
 
-…
+
+def clear_cuda_cache():
+    """
+    Clears the CUDA memory cache to free up GPU memory.
+    Works only if CUDA is available.
+    """
     gc.collect()
-    torch.cuda.…
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    else:
+        log.warning("CUDA is not available. No cache to clear.")
 
 
 def to_device(obj, device: Optional[torch.device], **kwargs):
@@ -75,7 +85,7 @@ def num_devices(devices: Union[int, List[int], str]) -> int:
     Return the number of devices.
 
     Args:
-        devices: `devices` can be a single int to specify the number of devices, or a list of device ids, e.g. [0, 1, 2, 3]
+        devices: `devices` can be a single int to specify the number of devices, or a list of device ids, e.g. [0, 1, 2, 3], or a str of device ids, e.g. "0,1,2,3" and "[0, 1, 2]".
 
     Returns:
         The number of devices.
```
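A quick sketch of the two touched helpers; the expected outputs in the comments follow the docstrings shown above and are not verified here:

```python
from fusion_bench.utils.devices import clear_cuda_cache, num_devices

# `devices` may also be a string of ids such as "0,1,2,3" or "[0, 1, 2]".
print(num_devices([0, 1, 2, 3]))  # 4
print(num_devices("0,1,2,3"))     # expected 4 per the updated docstring

# clear_cuda_cache() now degrades gracefully: on CPU-only machines it only
# runs gc.collect() and logs a warning instead of touching torch.cuda.
clear_cuda_cache()
```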
fusion_bench/utils/fabric.py
CHANGED
```diff
@@ -3,9 +3,9 @@ from typing import Optional
 
 import lightning as L
 
-from fusion_bench.utils.pylogger import …
+from fusion_bench.utils.pylogger import get_rankzero_logger
 
-log = …
+log = get_rankzero_logger(__name__)
 
 
 def seed_everything_by_time(fabric: Optional[L.Fabric] = None):
```
fusion_bench/utils/instantiate_utils.py
CHANGED
```diff
@@ -28,7 +28,7 @@ PRINT_FUNCTION_CALL_FUNC = print
 Function to be used for printing function calls.
 """
 
-CATCH_EXCEPTION = …
+CATCH_EXCEPTION = False
 
 
 @contextmanager
@@ -41,10 +41,12 @@ def set_print_function_call(value: bool):
     finally:
         PRINT_FUNCTION_CALL = old_value
 
+
 def set_print_function_call_permeanent(value: bool):
     global PRINT_FUNCTION_CALL
     PRINT_FUNCTION_CALL = value
 
+
 def is_instantiable(config: Union[DictConfig, Any]) -> bool:
     if OmegaConf.is_dict(config):
         return "_target_" in config
```
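`set_print_function_call` is a context manager (note the `@contextmanager` context line above), so temporary toggling looks like the sketch below; whether the flag actually causes `instantiate` to print its calls is inferred from the surrounding `PRINT_FUNCTION_CALL` names, not confirmed by this diff:

```python
from omegaconf import DictConfig

from fusion_bench.utils.instantiate_utils import (
    instantiate,
    set_print_function_call,
    set_print_function_call_permeanent,
)

cfg = DictConfig({"_target_": "collections.Counter"})

with set_print_function_call(True):   # flag restored when the block exits
    counter = instantiate(cfg)

set_print_function_call_permeanent(False)  # or flip the module-level flag for good
```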
fusion_bench/utils/lazy_imports.py
CHANGED
```diff
@@ -72,3 +72,26 @@ class LazyImporter(ModuleType):
 
     def __reduce__(self):
         return (self.__class__, (self._name, self.__file__, self._import_structure))
+
+
+class LazyModule(ModuleType):
+    """Module wrapper for lazy import.
+    Adapted from Optuna: https://github.com/optuna/optuna/blob/1f92d496b0c4656645384e31539e4ee74992ff55/optuna/__init__.py
+
+    This class wraps specified module and lazily import it when they are actually accessed.
+
+    Args:
+        name: Name of module to apply lazy import.
+    """
+
+    def __init__(self, name: str) -> None:
+        super().__init__(name)
+        self._name = name
+
+    def _load(self) -> ModuleType:
+        module = importlib.import_module(self._name)
+        self.__dict__.update(module.__dict__)
+        return module
+
+    def __getattr__(self, item: str) -> Any:
+        return getattr(self._load(), item)
```
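A minimal sketch of the new `LazyModule`, wrapping the standard-library `json` module purely for illustration:

```python
from fusion_bench.utils.lazy_imports import LazyModule

# The wrapped module is not imported yet; importlib.import_module runs on the
# first attribute access, via __getattr__ -> _load().
lazy_json = LazyModule("json")
print(lazy_json.dumps({"release": "0.2.22"}))  # triggers the real import here
```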
fusion_bench/utils/lazy_state_dict.py
CHANGED
```diff
@@ -2,7 +2,7 @@ import json
 import logging
 import os
 from copy import deepcopy
-from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Dict, Iterator, List, Mapping, Optional, Tuple, Type
 
 import torch
 from accelerate import init_empty_weights
@@ -49,7 +49,7 @@ def resolve_checkpoint_path(
     )
 
 
-class LazyStateDict:
+class LazyStateDict(Mapping[str, torch.Tensor]):
     """
     Dictionary-like object that lazily loads a state dict from a checkpoint path.
     """
@@ -168,12 +168,21 @@ class LazyStateDict:
     def config(self) -> "PretrainedConfig":
         return AutoConfig.from_pretrained(self._checkpoint)
 
+    @property
+    def dtype(self) -> torch.dtype:
+        """
+        `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+        """
+        first_key = next(iter(self.keys()))
+        first_param = self[first_key]
+        return first_param.dtype
+
     def state_dict(self, keep_vars: bool = False) -> "LazyStateDict":
         """
         Args:
             keep_vars (bool): Ignored, as LazyStateDict does not support keep_vars. Just for compatibility.
         """
-        return self
+        return deepcopy(self)
 
     def _resolve_checkpoint_files(self, checkpoint: str):
         # reference: https://huggingface.co/docs/accelerate/v0.17.1/en/usage_guides/big_modeling
@@ -290,6 +299,18 @@ class LazyStateDict:
         )
         return tensor
 
+    def pop(self, key: str):
+        assert key in list(
+            self.keys()
+        ), "KeyError: Cannot pop a tensor for a key that does not exist in the LazyStateDict."
+        if self._state_dict_cache is not None and key in self._state_dict_cache:
+            if key in self._index:
+                self._index.pop(key)
+            return self._state_dict_cache.pop(key)
+        if key in self._index:
+            self._index.pop(key)
+        return None
+
     def __setitem__(self, key: str, value: torch.Tensor) -> None:
         """
         Set a tensor in the LazyStateDict. This will update the state dict cache if it is enabled.
@@ -408,3 +429,17 @@ class LazyStateDict:
             raise KeyError(f"Key {key} not found in LazyStateDict.")
         for key, value in state_dict.items():
             self[key] = value
+
+    def __getattr__(self, name: str):
+        if "meta_module" in self.__dict__:
+            meta_module = self.__dict__["meta_module"]
+            if meta_module is not None:
+                if "_parameters" in meta_module.__dict__:
+                    if name in meta_module.__dict__["_parameters"]:
+                        return self.get_parameter(name)
+                if "_modules" in meta_module.__dict__:
+                    if name in meta_module.__dict__["_modules"]:
+                        return self.get_submodule(name)
+        raise AttributeError(
+            f"'{type(self).__name__}' object has no attribute '{name}'"
+        )
```
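Subclassing `Mapping[str, torch.Tensor]` gives `LazyStateDict` the read-only mapping protocol for free. A hedged sketch, assuming the constructor accepts a local checkpoint path or Hugging Face repo id as its first argument (as the class docstring and `resolve_checkpoint_path` suggest):

```python
from fusion_bench.utils.lazy_state_dict import LazyStateDict

sd = LazyStateDict("openai/clip-vit-base-patch32")  # assumed constructor usage

# Mapping protocol: len(), iteration, keys()/items(); tensors load lazily.
print(len(sd), sd.dtype)  # the new .dtype property reads the first parameter's dtype
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))
    break
```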