fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -1
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +16 -6
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
- fusion_bench/method/simple_average.py +16 -4
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +6 -6
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/misc.py +48 -2
- fusion_bench/utils/modelscope.py +265 -0
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +34 -27
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0

fusion_bench/mixins/serialization.py CHANGED

@@ -1,20 +1,148 @@
+import inspect
 import logging
+from copy import deepcopy
+from functools import wraps
+from inspect import Parameter, _ParameterKind
 from pathlib import Path
 from typing import Dict, Optional, Union

-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf

+from fusion_bench.constants import FUSION_BENCH_VERSION
 from fusion_bench.utils import import_object, instantiate
+from fusion_bench.utils.instantiate_utils import set_print_function_call

 log = logging.getLogger(__name__)

+__all__ = [
+    "YAMLSerializationMixin",
+    "auto_register_config",
+    "BaseYAMLSerializable",
+]
+
+
+def auto_register_config(cls):
+    """
+    Decorator to automatically register __init__ parameters in _config_mapping.
+
+    This decorator enhances classes that inherit from YAMLSerializationMixin by
+    automatically mapping constructor parameters to configuration keys and
+    dynamically setting instance attributes based on provided arguments.
+
+    The decorator performs the following operations:
+    1. Inspects the class's __init__ method signature
+    2. Automatically populates the _config_mapping dictionary with parameter names
+    3. Wraps the __init__ method to handle both positional and keyword arguments
+    4. Sets instance attributes for all constructor parameters
+    5. Applies default values when parameters are not provided
+
+    Args:
+        cls (YAMLSerializationMixin): The class to be decorated. Must inherit from
+            YAMLSerializationMixin to ensure proper serialization capabilities.
+
+    Returns:
+        YAMLSerializationMixin: The decorated class with enhanced auto-registration
+        functionality and modified __init__ behavior.
+
+    Behavior:
+        - **Parameter Registration**: All non-variadic parameters (excluding *args, **kwargs)
+          from the __init__ method are automatically added to _config_mapping
+        - **Positional Arguments**: Handled in order and mapped to corresponding parameter names
+        - **Keyword Arguments**: Processed after positional arguments, overriding any conflicts
+        - **Default Values**: Applied when parameters are not provided via arguments
+        - **Attribute Setting**: All parameters become instance attributes accessible via dot notation
+
+    Example:
+        ```python
+        @auto_register_config
+        class MyAlgorithm(BaseYAMLSerializable):
+            def __init__(self, learning_rate: float = 0.001, batch_size: int = 32, model_name: str = "default"):
+                super().__init__()
+
+        # All instantiation methods work automatically:
+        algo1 = MyAlgorithm(0.01, 64)  # positional args
+        algo2 = MyAlgorithm(learning_rate=0.01, model_name="bert")  # keyword args
+        algo3 = MyAlgorithm(0.01, batch_size=128, model_name="gpt")  # mixed args
+
+        # Attributes are automatically set and can be serialized:
+        print(algo1.learning_rate)  # 0.01
+        print(algo1.batch_size)  # 64
+        print(algo1.model_name)  # "default" (from default value)
+
+        config = algo1.config
+        # DictConfig({'_target_': 'MyAlgorithm', 'learning_rate': 0.01, 'batch_size': 64, 'model_name': 'default'})
+        ```
+
+    Note:
+        - The decorator wraps the original __init__ method while preserving its signature for IDE support
+        - Parameters with *args or **kwargs signatures are ignored during registration
+        - The attributes are auto-registered, then the original __init__ method is called,
+        - Type hints, method name, and other metadata are preserved using functools.wraps
+        - This decorator is designed to work seamlessly with the YAML serialization system
+
+    Raises:
+        AttributeError: If the class does not have the required _config_mapping attribute
+            infrastructure (should inherit from YAMLSerializationMixin)
+    """
+    original_init = cls.__init__
+    sig = inspect.signature(original_init)
+
+    # Auto-register parameters in _config_mapping
+    if not "_config_mapping" in cls.__dict__:
+        cls._config_mapping = deepcopy(getattr(cls, "_config_mapping", {}))
+    for param_name in list(sig.parameters.keys())[1:]:  # Skip 'self'
+        if sig.parameters[param_name].kind not in [
+            _ParameterKind.VAR_POSITIONAL,
+            _ParameterKind.VAR_KEYWORD,
+        ]:
+            cls._config_mapping[param_name] = param_name
+
+    def __init__(self, *args, **kwargs):
+        # auto-register the attributes based on the signature
+        sig = inspect.signature(original_init)
+        param_names = list(sig.parameters.keys())[1:]  # Skip 'self'
+
+        # Handle positional arguments
+        for i, arg_value in enumerate(args):
+            if i < len(param_names):
+                param_name = param_names[i]
+                if sig.parameters[param_name].kind not in [
+                    _ParameterKind.VAR_POSITIONAL,
+                    _ParameterKind.VAR_KEYWORD,
+                ]:
+                    setattr(self, param_name, arg_value)
+
+        # Handle keyword arguments and defaults
+        for param_name in param_names:
+            if sig.parameters[param_name].kind not in [
+                _ParameterKind.VAR_POSITIONAL,
+                _ParameterKind.VAR_KEYWORD,
+            ]:
+                # Skip if already set by positional argument
+                param_index = param_names.index(param_name)
+                if param_index >= 0 and param_index < len(args):
+                    continue
+
+                if param_name in kwargs:
+                    setattr(self, param_name, kwargs[param_name])
+                else:
+                    # Set default value if available and attribute doesn't exist
+                    default_value = sig.parameters[param_name].default
+                    if default_value is not Parameter.empty:
+                        setattr(self, param_name, default_value)
+
+        # Call the original __init__
+        result = original_init(self, *args, **kwargs)
+        return result
+
+    # Replace the original __init__ method while preserving its signature
+    cls.__init__ = __init__
+    return cls
+

 class YAMLSerializationMixin:
-    _recursive_: bool = False
     _config_key: Optional[str] = None
-    _config_mapping: Dict[str, str] = {
-        "_recursive_": "_recursive_",
-    }
+    _config_mapping: Dict[str, str] = {}
     R"""
     `_config_mapping` is a dictionary mapping the attribute names of the class to the config option names. This is used to convert the class to a DictConfig.


@@ -47,46 +175,50 @@ class YAMLSerializationMixin:
     By default, the `_target_` key is set to the class name as `type(self).__name__`.
     """

-    def __init__(
-        self,
-        _recursive_: bool = False,
-        **kwargs,
-    ) -> None:
-        self._recursive_ = _recursive_
+    def __init__(self, **kwargs) -> None:
         for key, value in kwargs.items():
             log.warning(f"Unused argument: {key}={value}")

     @property
-    def config(self):
+    def config(self) -> DictConfig:
         R"""
         Returns the configuration of the model pool as a DictConfig.

-        This property
-
-        serialization or other purposes.
+        This property converts the model pool instance into a dictionary
+        configuration, which can be used for serialization or other purposes.

         Example:
-
-
-
-
+
+        ```python
+        model = SomeModelFusionAlgorithm(hyper_param_1=1, hyper_param_2=2)
+        config = model.config
+        print(config)
+        # DictConfig({'_target_': 'SomeModelFusionAlgorithm', 'hyper_param_1': 1, 'hyper_param_2': 2})
+        ```

         This is useful for serializing the object to a YAML file or for debugging.

         Returns:
             DictConfig: The configuration of the model pool.
         """
-
+        config = {"_target_": f"{type(self).__module__}.{type(self).__qualname__}"}
+        for attr, key in self._config_mapping.items():
+            if hasattr(self, attr):
+                config[key] = getattr(self, attr)

-
+        try:
+            return OmegaConf.create(config)
+        except Exception as e:
+            return OmegaConf.create(config, flags={"allow_objects": True})
+
+    def to_yaml(self, path: Union[str, Path], resolve: bool = True):
         """
         Save the model pool to a YAML file.

         Args:
             path (Union[str, Path]): The path to save the model pool to.
         """
-        config =
-        OmegaConf.save(config, path, resolve=True)
+        OmegaConf.save(self.config, path, resolve=resolve)

     @classmethod
     def from_yaml(cls, path: Union[str, Path]):

@@ -108,41 +240,126 @@ class YAMLSerializationMixin:
                 f"The class {target_cls.__name__} is not the same as the class {cls.__name__}. "
                 f"Instantiating the class {target_cls.__name__} instead."
             )
-
-            config
-            _recursive_=(
-                cls._recursive_
-                if config.get("_recursive_") is None
-                else config.get("_recursive_")
-            ),
-        )
+        with set_print_function_call(False):
+            return instantiate(config)

-    def
+    def register_parameter_to_config(
+        self,
+        attr_name: str,
+        param_name: str,
+        value,
+    ):
         """
-
+        Set an attribute value and register its config mapping.

-
-
+        This method allows dynamic setting of object attributes while simultaneously
+        updating the configuration mapping that defines how the attribute should
+        be serialized in the configuration output.
+
+        Args:
+            attr_name (str): The name of the attribute to set on this object.
+            arg_name (str): The corresponding parameter name to use in the config
+                serialization. This is how the attribute will appear in YAML output.
+            value: The value to assign to the attribute.
+
+        Example:
+            ```python
+            model = BaseYAMLSerializable()
+            model.set_option("learning_rate", "lr", 0.001)
+
+            # This sets model.learning_rate = 0.001
+            # and maps it to "lr" in the config output
+            config = model.config
+            # config will contain: {"lr": 0.001, ...}
+            ```
         """
-
-
-
-
-
+        setattr(self, attr_name, value)
+        self._config_mapping[attr_name] = param_name
+
+
+@auto_register_config
+class BaseYAMLSerializable(YAMLSerializationMixin):
+    """
+    A base class for YAML-serializable classes with enhanced metadata support.
+
+    This class extends `YAMLSerializationMixin` to provide additional metadata
+    fields commonly used in FusionBench classes, including usage information
+    and version tracking. It serves as a foundation for all serializable
+    model components in the framework.
+
+    The class automatically handles serialization of usage and version metadata
+    alongside the standard configuration parameters, making it easier to track
+    model provenance and intended usage patterns.

+    Attributes:
+        _usage_ (Optional[str]): Description of the model's intended usage or purpose.
+        _version_ (Optional[str]): Version information for the model or configuration.

-
-
-
-
-
+    Example:
+        ```python
+        class MyAlgorithm(BaseYAMLSerializable):
+            _config_mapping = BaseYAMLSerializable._config_mapping | {
+                "model_name": "model_name",
+                "num_layers": "num_layers",
+            }
+
+            def __init__(self, _usage_: str = None, _version_: str = None):
+                super().__init__(_usage_=_usage_, _version_=_version_)
+
+        # Usage with metadata
+        model = MyAlgorithm(
+            _usage_="Text classification fine-tuning",
+            _version_="1.0.0"
+        )
+
+        # Serialization includes metadata
+        config = model.config
+        # DictConfig({
+        #     '_target_': 'MyModel',
+        #     '_usage_': 'Text classification fine-tuning',
+        #     '_version_': '1.0.0'
+        # })
+        ```
+
+    Note:
+        The underscore prefix in `_usage_` and `_version_` follows the convention
+        for metadata fields that are not core model parameters but provide
+        important contextual information for model management and tracking.
+    """

     def __init__(
         self,
+        _recursive_: bool = False,
         _usage_: Optional[str] = None,
-        _version_: Optional[str] =
+        _version_: Optional[str] = FUSION_BENCH_VERSION,
         **kwargs,
     ):
+        """
+        Initialize a base YAML-serializable model with metadata support.
+
+        Args:
+            _usage_ (Optional[str], optional): Description of the model's intended
+                usage or purpose. This can include information about the training
+                domain, expected input types, or specific use cases. Defaults to None.
+            _version_ (Optional[str], optional): Version information for the model
+                or configuration. Can be used to track model iterations, dataset
+                versions, or compatibility information. Defaults to None.
+            **kwargs: Additional keyword arguments passed to the parent class.
+                Unused arguments will trigger warnings via the parent's initialization.
+
+        Example:
+            ```python
+            model = BaseYAMLSerializable(
+                _usage_="Image classification on CIFAR-10",
+                _version_="2.1.0"
+            )
+            ```
+        """
         super().__init__(**kwargs)
-
-
+        if _version_ != FUSION_BENCH_VERSION:
+            log.warning(
+                f"Current fusion-bench version is {FUSION_BENCH_VERSION}, but the serialized version is {_version_}. "
+                "Attempting to use current version."
+            )
+            # override _version_ with current fusion-bench version
+            self._version_ = FUSION_BENCH_VERSION
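
Taken together, `auto_register_config` and `BaseYAMLSerializable` give downstream classes an automatic YAML round-trip: constructor parameters become attributes, `config` collects them into a `DictConfig` with a `_target_` entry, and `to_yaml`/`from_yaml` save and re-instantiate. A minimal sketch of the intended usage (the `MyMergeMethod` class and its `scaling_factor` parameter are hypothetical, not part of the package):

```python
from fusion_bench.mixins.serialization import BaseYAMLSerializable, auto_register_config


@auto_register_config
class MyMergeMethod(BaseYAMLSerializable):
    """Hypothetical example class; its constructor parameters are registered automatically."""

    def __init__(self, scaling_factor: float = 0.3, **kwargs):
        super().__init__(**kwargs)


method = MyMergeMethod(scaling_factor=0.5)
print(method.config)
# DictConfig containing `_target_`, `scaling_factor`, and the `_usage_`/`_version_` metadata
method.to_yaml("my_merge_method.yaml")  # OmegaConf.save of the config above
```

Note that the wrapped `__init__` assigns the attributes before the original constructor body runs, so `self.scaling_factor` is already available inside `__init__`.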

fusion_bench/modelpool/__init__.py CHANGED

@@ -17,7 +17,7 @@ _import_structure = {
         "HuggingFaceGPT2ClassificationPool",
         "GPT2ForSequenceClassificationPool",
     ],
-    "seq_classification_lm": ["
+    "seq_classification_lm": ["SequenceClassificationModelPool"],
 }


@@ -34,7 +34,7 @@ if TYPE_CHECKING:
     from .openclip_vision import OpenCLIPVisionModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
     from .seq2seq_lm import Seq2SeqLMPool
-    from .seq_classification_lm import
+    from .seq_classification_lm import SequenceClassificationModelPool

 else:
     sys.modules[__name__] = LazyImporter(

fusion_bench/modelpool/base_pool.py CHANGED

@@ -1,13 +1,13 @@
 import logging
 from copy import deepcopy
-from typing import Dict, List, Optional, Union
+from typing import Dict, Generator, List, Optional, Tuple, Union

 import torch
 from omegaconf import DictConfig
 from torch import nn
 from torch.utils.data import Dataset

-from fusion_bench.mixins import
+from fusion_bench.mixins import BaseYAMLSerializable, HydraConfigMixin
 from fusion_bench.utils import instantiate, timeit_context

 __all__ = ["BaseModelPool"]

@@ -15,7 +15,10 @@ __all__ = ["BaseModelPool"]
 log = logging.getLogger(__name__)


-class BaseModelPool(
+class BaseModelPool(
+    HydraConfigMixin,
+    BaseYAMLSerializable,
+):
     """
     A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain image preprocessor, and a language model pool may contain a tokenizer.


@@ -31,7 +34,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
     _program = None
     _config_key = "modelpool"
     _models: Union[DictConfig, Dict[str, nn.Module]]
-    _config_mapping =
+    _config_mapping = BaseYAMLSerializable._config_mapping | {
         "_models": "models",
         "_train_datasets": "train_datasets",
         "_val_datasets": "val_datasets",

@@ -56,7 +59,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         super().__init__(**kwargs)

     @property
-    def has_pretrained(self):
+    def has_pretrained(self) -> bool:
         """
         Check if the model pool contains a pretrained model.


@@ -125,7 +128,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         return len(self.model_names)

     @staticmethod
-    def is_special_model(model_name: str):
+    def is_special_model(model_name: str) -> bool:
         """
         Determine if a model is special based on its name.


@@ -152,6 +155,23 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
             model_config = deepcopy(model_config)
         return model_config

+    def get_model_path(self, model_name: str) -> str:
+        """
+        Get the path for the specified model.
+
+        Args:
+            model_name (str): The name of the model.
+
+        Returns:
+            str: The path for the specified model.
+        """
+        if isinstance(self._models[model_name], str):
+            return self._models[model_name]
+        else:
+            raise ValueError(
+                "Model path is not a string. Try to override this method in derived modelpool class."
+            )
+
     def load_model(
         self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
     ) -> nn.Module:

@@ -159,7 +179,7 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         Load a model from the pool based on the provided configuration.

         Args:
-
+            model_name_or_config (Union[str, DictConfig]): The model name or configuration.

         Returns:
             nn.Module: The instantiated model.

@@ -201,11 +221,11 @@ class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
         model = self.load_model(self.model_names[0], *args, **kwargs)
         return model

-    def models(self):
+    def models(self) -> Generator[nn.Module, None, None]:
         for model_name in self.model_names:
             yield self.load_model(model_name)

-    def named_models(self):
+    def named_models(self) -> Generator[Tuple[str, nn.Module], None, None]:
         for model_name in self.model_names:
             yield model_name, self.load_model(model_name)


fusion_bench/modelpool/causal_lm/causal_lm.py CHANGED

@@ -57,6 +57,15 @@ class CausalLMPool(BaseModelPool):
         )
         self.load_lazy = load_lazy

+    def get_model_path(self, model_name: str):
+        model_name_or_config = self._models[model_name]
+        if isinstance(model_name_or_config, str):
+            return model_name_or_config
+        elif isinstance(model_name_or_config, (DictConfig, dict)):
+            return model_name_or_config.get("pretrained_model_name_or_path")
+        else:
+            raise RuntimeError("Invalid model configuration")
+
     @override
     def load_model(
         self,

fusion_bench/modelpool/clip_vision/modelpool.py CHANGED

@@ -1,6 +1,6 @@
 import logging
 from copy import deepcopy
-from typing import Optional, Union
+from typing import Literal, Optional, Union

 from datasets import load_dataset
 from lightning.fabric.utilities import rank_zero_only

@@ -11,6 +11,7 @@ from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
 from typing_extensions import override

 from fusion_bench.utils import instantiate, timeit_context
+from fusion_bench.utils.modelscope import resolve_repo_path

 from ..base_pool import BaseModelPool


@@ -25,25 +26,32 @@ class CLIPVisionModelPool(BaseModelPool):
     the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.
     """

-    _config_mapping = BaseModelPool._config_mapping | {
+    _config_mapping = BaseModelPool._config_mapping | {
+        "_processor": "processor",
+        "_platform": "hf",
+    }

     def __init__(
         self,
         models: DictConfig,
         *,
         processor: Optional[DictConfig] = None,
+        platform: Literal["hf", "huggingface", "modelscope"] = "hf",
         **kwargs,
     ):
         super().__init__(models, **kwargs)
-
         self._processor = processor
+        self._platform = platform

     def load_processor(self, *args, **kwargs) -> CLIPProcessor:
         assert self._processor is not None, "Processor is not defined in the config"
         if isinstance(self._processor, str):
             if rank_zero_only.rank == 0:
                 log.info(f"Loading `transformers.CLIPProcessor`: {self._processor}")
-
+            repo_path = resolve_repo_path(
+                repo_id=self._processor, repo_type="model", platform=self._platform
+            )
+            processor = CLIPProcessor.from_pretrained(repo_path, *args, **kwargs)
         else:
             processor = instantiate(self._processor, *args, **kwargs)
         return processor

@@ -54,7 +62,10 @@ class CLIPVisionModelPool(BaseModelPool):
         if isinstance(model_config, str):
             if rank_zero_only.rank == 0:
                 log.info(f"Loading `transformers.CLIPModel`: {model_config}")
-
+            repo_path = resolve_repo_path(
+                repo_id=model_config, repo_type="model", platform=self._platform
+            )
+            clip_model = CLIPModel.from_pretrained(repo_path, *args, **kwargs)
             return clip_model
         else:
             assert isinstance(

@@ -107,14 +118,17 @@ class CLIPVisionModelPool(BaseModelPool):
         if isinstance(model, str):
             if rank_zero_only.rank == 0:
                 log.info(f"Loading `transformers.CLIPVisionModel`: {model}")
-
+            repo_path = resolve_repo_path(
+                model, repo_type="model", platform=self._platform
+            )
+            return CLIPVisionModel.from_pretrained(repo_path, *args, **kwargs)
         if isinstance(model, nn.Module):
             if rank_zero_only.rank == 0:
                 log.info(f"Returning existing model: {model}")
             return model
-
-
-
+        else:
+            # If the model is not a string, we use the default load_model method
+            return super().load_model(model_name_or_config, *args, **kwargs)

     def load_train_dataset(self, dataset_name: str, *args, **kwargs):
         dataset_config = self._train_datasets[dataset_name]

@@ -123,7 +137,7 @@ class CLIPVisionModelPool(BaseModelPool):
             log.info(
                 f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
             )
-            dataset =
+            dataset = self._load_dataset(dataset_config, split="train")
         else:
             dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
         return dataset

@@ -135,7 +149,7 @@ class CLIPVisionModelPool(BaseModelPool):
             log.info(
                 f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
             )
-            dataset =
+            dataset = self._load_dataset(dataset_config, split="validation")
         else:
             dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
         return dataset

@@ -147,7 +161,24 @@ class CLIPVisionModelPool(BaseModelPool):
             log.info(
                 f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
             )
-            dataset =
+            dataset = self._load_dataset(dataset_config, split="test")
         else:
             dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
         return dataset
+
+    def _load_dataset(self, name: str, split: str):
+        """
+        Load a dataset by its name and split.
+
+        Args:
+            dataset_name (str): The name of the dataset.
+            split (str): The split of the dataset to load (e.g., "train", "validation", "test").
+
+        Returns:
+            Dataset: The loaded dataset.
+        """
+        datset_dir = resolve_repo_path(
+            name, repo_type="dataset", platform=self._platform
+        )
+        dataset = load_dataset(datset_dir, split=split)
+        return dataset
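
The `platform` switch added above routes checkpoint and dataset lookups through `resolve_repo_path`, so the same pool definition can resolve repositories from the Hugging Face Hub or ModelScope. A minimal sketch, assuming a pool whose model and processor entries are plain repo-id strings (the model key and repo ids below are illustrative, not taken from the shipped configs):

```python
from omegaconf import DictConfig

from fusion_bench.modelpool.clip_vision.modelpool import CLIPVisionModelPool

pool = CLIPVisionModelPool(
    models=DictConfig({"clip-vit-base-patch32": "openai/clip-vit-base-patch32"}),
    processor="openai/clip-vit-base-patch32",
    platform="modelscope",  # or "hf" / "huggingface" (the default)
)

# String entries are first resolved to a repo path on the chosen platform,
# then handed to the usual `from_pretrained` calls.
processor = pool.load_processor()
vision_model = pool.load_model("clip-vit-base-patch32")
```

With `platform="modelscope"` the identifiers are resolved against the ModelScope hub; with the default `"hf"` they resolve against the Hugging Face Hub.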

fusion_bench/modelpool/seq_classification_lm/__init__.py CHANGED

@@ -1,2 +1,2 @@
 from .reward_model import create_reward_model_from_pretrained
-from .seq_classification_lm import
+from .seq_classification_lm import SequenceClassificationModelPool

fusion_bench/models/__init__.py CHANGED