fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -1
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +16 -6
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
- fusion_bench/method/simple_average.py +16 -4
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +6 -6
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/misc.py +48 -2
- fusion_bench/utils/modelscope.py +265 -0
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +34 -27
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
|
@@ -22,6 +22,7 @@ import torch
|
|
|
22
22
|
from torch import Tensor, nn
|
|
23
23
|
from torch.func import functional_call
|
|
24
24
|
|
|
25
|
+
from fusion_bench.models.utils import StateDictType, del_attr, get_attr, set_attr
|
|
25
26
|
from fusion_bench.utils.type import StateDictType, TorchModelType
|
|
26
27
|
|
|
27
28
|
log = logging.getLogger(__name__)
|
|
@@ -29,77 +30,7 @@ log = logging.getLogger(__name__)
|
|
|
29
30
|
__all__ = ["get_task_wise_weights", "fuse_weights", "TaskWiseMergedModel"]
|
|
30
31
|
|
|
31
32
|
|
|
32
|
-
def del_attr(obj, names: List[str]):
|
|
33
|
-
"""
|
|
34
|
-
Deletes an attribute from an object recursively.
|
|
35
|
-
|
|
36
|
-
Args:
|
|
37
|
-
obj (object): Object to delete attribute from.
|
|
38
|
-
names (list): List of attribute names to delete recursively.
|
|
39
|
-
"""
|
|
40
|
-
if len(names) == 1:
|
|
41
|
-
delattr(obj, names[0])
|
|
42
|
-
else:
|
|
43
|
-
del_attr(getattr(obj, names[0]), names[1:])
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
def set_attr(obj, names: List[str], val):
|
|
47
|
-
"""
|
|
48
|
-
Sets an attribute of an object recursively.
|
|
49
|
-
|
|
50
|
-
Args:
|
|
51
|
-
obj (object): Object to set attribute of.
|
|
52
|
-
names (list): List of attribute names to set recursively.
|
|
53
|
-
val (object): Value to set the attribute to.
|
|
54
|
-
"""
|
|
55
|
-
if len(names) == 1:
|
|
56
|
-
setattr(obj, names[0], val)
|
|
57
|
-
else:
|
|
58
|
-
set_attr(getattr(obj, names[0]), names[1:], val)
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
def get_attr(obj, names: List[str]):
|
|
62
|
-
"""
|
|
63
|
-
Gets an attribute of an object recursively.
|
|
64
|
-
|
|
65
|
-
Args:
|
|
66
|
-
obj (object): Object to get attribute of.
|
|
67
|
-
names (list): List of attribute names to get recursively.
|
|
68
|
-
|
|
69
|
-
Returns:
|
|
70
|
-
object: The attribute of the object.
|
|
71
|
-
"""
|
|
72
|
-
if len(names) == 1:
|
|
73
|
-
return getattr(obj, names[0])
|
|
74
|
-
else:
|
|
75
|
-
return get_attr(getattr(obj, names[0]), names[1:])
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
def check_parameterNamesMatch(checkpoints: List[StateDictType]) -> None:
|
|
79
|
-
"""
|
|
80
|
-
Checks that the parameter names of the given checkpoints match.
|
|
81
|
-
|
|
82
|
-
Args:
|
|
83
|
-
checkpoints (List[Dict[str, float]]): A list of checkpoints, where each checkpoint is a dictionary of parameter names and their corresponding values.
|
|
84
|
-
|
|
85
|
-
Raises:
|
|
86
|
-
ValueError: If the number of checkpoints is less than 2 or if the parameter names of any two checkpoints differ.
|
|
87
|
-
|
|
88
|
-
"""
|
|
89
|
-
parameter_names = set(checkpoints[0].keys())
|
|
90
|
-
|
|
91
|
-
if len(checkpoints) >= 2:
|
|
92
|
-
# raise ValueError("Number of models is less than 2.")
|
|
93
|
-
for checkpoint in checkpoints[1:]:
|
|
94
|
-
current_parameterNames = set(checkpoint.keys())
|
|
95
|
-
if current_parameterNames != parameter_names:
|
|
96
|
-
raise ValueError(
|
|
97
|
-
"Differing parameter names in models. "
|
|
98
|
-
f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
|
|
99
|
-
)
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
def get_task_wise_weights(num_models: int, init_values: float = None):
|
|
33
|
+
def get_task_wise_weights(num_models: int, init_values: float = None) -> Tensor:
|
|
103
34
|
"""
|
|
104
35
|
This function generates a tensor of weights for each model.
|
|
105
36
|
|
|
@@ -116,7 +47,7 @@ def get_task_wise_weights(num_models: int, init_values: float = None):
|
|
|
116
47
|
return torch.full((num_models,), init_values, dtype=torch.float32)
|
|
117
48
|
|
|
118
49
|
|
|
119
|
-
def _fuse_weights(task_wise_weight: Tensor, tensors: List[Tensor]):
|
|
50
|
+
def _fuse_weights(task_wise_weight: Tensor, tensors: List[Tensor]) -> Tensor:
|
|
120
51
|
"""
|
|
121
52
|
This function fuses the weights of the models.
|
|
122
53
|
|
|
@@ -158,6 +89,100 @@ def fuse_weights(
|
|
|
158
89
|
|
|
159
90
|
|
|
160
91
|
class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
|
|
92
|
+
"""
|
|
93
|
+
A PyTorch module that dynamically merges multiple fine-tuned models using learnable task-wise weights.
|
|
94
|
+
|
|
95
|
+
This class implements a sophisticated model fusion approach where multiple task-specific models
|
|
96
|
+
are combined with a pretrained base model using learnable weights. The fusion is performed
|
|
97
|
+
using task vectors (differences between fine-tuned and pretrained models) that are weighted
|
|
98
|
+
and added to the base model's parameters.
|
|
99
|
+
|
|
100
|
+
The key innovation is that the merging weights are learnable parameters that can be optimized
|
|
101
|
+
during training, allowing the model to automatically learn the optimal combination of different
|
|
102
|
+
task-specific knowledge.
|
|
103
|
+
|
|
104
|
+
Architecture:
|
|
105
|
+
- Base pretrained model (frozen)
|
|
106
|
+
- Multiple task vectors (differences from pretrained model, frozen)
|
|
107
|
+
- Learnable task-wise weights (trainable parameters)
|
|
108
|
+
- Dynamic merging during forward pass
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
task_wise_weight (Tensor): Initial weights for each task model. Shape: (num_models,).
|
|
112
|
+
These become learnable parameters that control the contribution of each task vector.
|
|
113
|
+
pretrained_model (TorchModelType): The base pretrained model that serves as the foundation.
|
|
114
|
+
This model is frozen and used as the starting point for merging.
|
|
115
|
+
finetuned_models (List[TorchModelType]): List of fine-tuned models for different tasks.
|
|
116
|
+
These are converted to task vectors (differences from pretrained model) and frozen.
|
|
117
|
+
clamp_weights (bool, optional): Whether to clamp merge weights to [0, 1] range.
|
|
118
|
+
Defaults to True. When True, ensures weights are non-negative and bounded.
|
|
119
|
+
tie_weights (bool, optional): Whether to tie weights during functional call.
|
|
120
|
+
Defaults to False. Used in the underlying PyTorch functional_call.
|
|
121
|
+
strict (bool, optional): Whether to enforce strict parameter matching.
|
|
122
|
+
Defaults to True. Used in the underlying PyTorch functional_call.
|
|
123
|
+
task_vector_dtype (Optional[torch.dtype], optional): Data type for task vectors.
|
|
124
|
+
Defaults to None. Can be used to save memory (e.g., torch.float16).
|
|
125
|
+
|
|
126
|
+
Attributes:
|
|
127
|
+
merge_weight (nn.Parameter): Learnable weights for merging task vectors.
|
|
128
|
+
pretrained_model (TorchModelType): The frozen base model.
|
|
129
|
+
task_vectors (nn.ModuleList): List of frozen task vector models.
|
|
130
|
+
_merged_state_dict (StateDictType): Cached merged state dictionary.
|
|
131
|
+
|
|
132
|
+
Example:
|
|
133
|
+
```python
|
|
134
|
+
import torch
|
|
135
|
+
import torch.nn as nn
|
|
136
|
+
|
|
137
|
+
# Create example models
|
|
138
|
+
pretrained_model = nn.Linear(10, 5)
|
|
139
|
+
finetuned_model1 = nn.Linear(10, 5) # Fine-tuned on task 1
|
|
140
|
+
finetuned_model2 = nn.Linear(10, 5) # Fine-tuned on task 2
|
|
141
|
+
|
|
142
|
+
# Initialize task-wise weights
|
|
143
|
+
task_weights = torch.tensor([0.3, 0.7]) # Initial weights for 2 tasks
|
|
144
|
+
|
|
145
|
+
# Create merged model
|
|
146
|
+
merged_model = TaskWiseMergedModel(
|
|
147
|
+
task_wise_weight=task_weights,
|
|
148
|
+
pretrained_model=pretrained_model,
|
|
149
|
+
finetuned_models=[finetuned_model1, finetuned_model2],
|
|
150
|
+
clamp_weights=True
|
|
151
|
+
)
|
|
152
|
+
|
|
153
|
+
# Use like a regular PyTorch model
|
|
154
|
+
x = torch.randn(32, 10)
|
|
155
|
+
output = merged_model(x)
|
|
156
|
+
|
|
157
|
+
# Train the merge weights
|
|
158
|
+
optimizer = torch.optim.Adam(merged_model.parameters())
|
|
159
|
+
loss = some_loss_function(output, targets)
|
|
160
|
+
loss.backward()
|
|
161
|
+
optimizer.step()
|
|
162
|
+
|
|
163
|
+
# Get the final merged model
|
|
164
|
+
final_model = merged_model.merge_and_unload()
|
|
165
|
+
```
|
|
166
|
+
|
|
167
|
+
Training Workflow:
|
|
168
|
+
1. **Initialization**: Task vectors are computed as differences from pretrained model
|
|
169
|
+
2. **Forward Pass**: Weights are dynamically merged based on current merge_weight values
|
|
170
|
+
3. **Loss Computation**: Standard loss computation on model outputs
|
|
171
|
+
4. **Backpropagation**: Gradients flow through merge_weight parameters
|
|
172
|
+
5. **Optimization**: merge_weight parameters are updated to improve performance
|
|
173
|
+
|
|
174
|
+
Memory Efficiency:
|
|
175
|
+
- Task vectors can use lower precision (task_vector_dtype)
|
|
176
|
+
- Base model and task vectors are frozen (no gradient computation)
|
|
177
|
+
- Only merge weights require gradients
|
|
178
|
+
|
|
179
|
+
Note:
|
|
180
|
+
- The pretrained model and task vectors are frozen during training
|
|
181
|
+
- Only the merge weights (task_wise_weight) are trainable parameters
|
|
182
|
+
- Task vectors represent the difference between fine-tuned and pretrained models
|
|
183
|
+
- The merged state dict is cached; forward() only computes it when no cached copy exists, so call merge_weights() again after the merge weights change
|
|
184
|
+
"""
|
|
185
|
+
|
|
161
186
|
_merged_state_dict: StateDictType = None
|
|
162
187
|
|
|
163
188
|
def __init__(
|
|
@@ -170,6 +195,32 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
|
|
|
170
195
|
strict: bool = True,
|
|
171
196
|
task_vector_dtype: Optional[torch.dtype] = None,
|
|
172
197
|
):
|
|
198
|
+
"""
|
|
199
|
+
Initialize the TaskWiseMergedModel.
|
|
200
|
+
|
|
201
|
+
This constructor sets up the model by:
|
|
202
|
+
1. Converting fine-tuned models to task vectors (differences from pretrained)
|
|
203
|
+
2. Freezing the pretrained model and task vectors
|
|
204
|
+
3. Setting up learnable merge weights as parameters
|
|
205
|
+
4. Configuring merging behavior options
|
|
206
|
+
|
|
207
|
+
Args:
|
|
208
|
+
task_wise_weight (Tensor): Initial weights for each task model. Shape: (num_models,).
|
|
209
|
+
These values become the starting point for learnable parameters.
|
|
210
|
+
pretrained_model (TorchModelType): The base pretrained model.
|
|
211
|
+
Will be frozen and used as the foundation for merging.
|
|
212
|
+
finetuned_models (List[TorchModelType]): List of fine-tuned models.
|
|
213
|
+
Must have the same architecture as pretrained_model.
|
|
214
|
+
clamp_weights (bool, optional): Whether to clamp weights to [0, 1]. Defaults to True.
|
|
215
|
+
tie_weights (bool, optional): Whether to tie weights in functional_call. Defaults to False.
|
|
216
|
+
strict (bool, optional): Whether to use strict parameter matching. Defaults to True.
|
|
217
|
+
task_vector_dtype (Optional[torch.dtype], optional): Data type for task vectors.
|
|
218
|
+
Defaults to None (same as original models).
|
|
219
|
+
|
|
220
|
+
Raises:
|
|
221
|
+
ValueError: If the number of task_wise_weights doesn't match the number of fine-tuned models.
|
|
222
|
+
RuntimeError: If models have incompatible architectures.
|
|
223
|
+
"""
|
|
173
224
|
super().__init__()
|
|
174
225
|
self.clamp_weights = clamp_weights
|
|
175
226
|
self.tie_weights = tie_weights
|
|
@@ -196,6 +247,24 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
|
|
|
196
247
|
|
|
197
248
|
@property
|
|
198
249
|
def forward_model(self):
|
|
250
|
+
"""
|
|
251
|
+
Get a functional model with merged parameters.
|
|
252
|
+
|
|
253
|
+
Returns a partial function that applies the pretrained model with the current
|
|
254
|
+
merged state dictionary. This allows for efficient forward passes without
|
|
255
|
+
modifying the original model's parameters.
|
|
256
|
+
|
|
257
|
+
Returns:
|
|
258
|
+
Callable: A partial function that can be called with (args, kwargs) to
|
|
259
|
+
perform forward pass with merged parameters.
|
|
260
|
+
|
|
261
|
+
Example:
|
|
262
|
+
```python
|
|
263
|
+
# Internal usage during forward pass
|
|
264
|
+
forward_fn = merged_model.forward_model
|
|
265
|
+
output = forward_fn(args=(x,), kwargs={})
|
|
266
|
+
```
|
|
267
|
+
"""
|
|
199
268
|
return functools.partial(
|
|
200
269
|
functional_call,
|
|
201
270
|
self.pretrained_model,
|
|
@@ -205,6 +274,43 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
|
|
|
205
274
|
)
|
|
206
275
|
|
|
207
276
|
def merge_weights(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
|
|
277
|
+
"""
|
|
278
|
+
Merge task vectors with the pretrained model using current merge weights.
|
|
279
|
+
|
|
280
|
+
This method computes the merged model parameters by combining the pretrained
|
|
281
|
+
model with weighted task vectors. The resulting state dictionary represents
|
|
282
|
+
a model that incorporates knowledge from all task-specific models.
|
|
283
|
+
|
|
284
|
+
The merging formula for each parameter is:
|
|
285
|
+
merged_param = pretrained_param + Σ(weight_i * task_vector_i * mask_i)
|
|
286
|
+
|
|
287
|
+
Args:
|
|
288
|
+
task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
|
|
289
|
+
to selectively apply task vectors to specific parameters. Keys should
|
|
290
|
+
match parameter names, values should be tensors with the same shape
|
|
291
|
+
as the corresponding parameters. Defaults to None (no masking).
|
|
292
|
+
|
|
293
|
+
Returns:
|
|
294
|
+
StateDictType: The merged state dictionary containing combined parameters.
|
|
295
|
+
|
|
296
|
+
Example:
|
|
297
|
+
```python
|
|
298
|
+
# Basic merging
|
|
299
|
+
merged_state = model.merge_weights()
|
|
300
|
+
|
|
301
|
+
# Merging with parameter-specific masks
|
|
302
|
+
masks = {
|
|
303
|
+
'layer1.weight': torch.ones_like(model.pretrained_model.layer1.weight),
|
|
304
|
+
'layer2.weight': torch.zeros_like(model.pretrained_model.layer2.weight),
|
|
305
|
+
}
|
|
306
|
+
masked_state = model.merge_weights(task_vector_mask=masks)
|
|
307
|
+
```
|
|
308
|
+
|
|
309
|
+
Note:
|
|
310
|
+
- If clamp_weights is True, merge weights are clamped to [0, 1] range
|
|
311
|
+
- The merged state dict is cached in _merged_state_dict
|
|
312
|
+
- Task vector masks allow fine-grained control over which parameters are affected
|
|
313
|
+
"""
|
|
208
314
|
if self.clamp_weights:
|
|
209
315
|
merge_weight = self.merge_weight.clamp(0, 1)
|
|
210
316
|
else:
|
|
@@ -222,11 +328,83 @@ class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
|
|
|
222
328
|
return state_dict
|
|
223
329
|
|
|
224
330
|
def merge_and_unload(self, task_vector_mask: Optional[Dict[str, Tensor]] = None):
|
|
331
|
+
"""
|
|
332
|
+
Merge models and return the final merged model.
|
|
333
|
+
|
|
334
|
+
This method performs the merging operation and then loads the merged parameters
|
|
335
|
+
into the pretrained model, returning a standard PyTorch model that can be used
|
|
336
|
+
independently of the TaskWiseMergedModel wrapper.
|
|
337
|
+
|
|
338
|
+
Args:
|
|
339
|
+
task_vector_mask (Optional[Dict[str, Tensor]], optional): Optional masks
|
|
340
|
+
for selective parameter merging. Defaults to None.
|
|
341
|
+
|
|
342
|
+
Returns:
|
|
343
|
+
TorchModelType: The pretrained model with merged parameters loaded.
|
|
344
|
+
This is a standalone model that can be used without the wrapper.
|
|
345
|
+
|
|
346
|
+
Example:
|
|
347
|
+
```python
|
|
348
|
+
# Train the merged model
|
|
349
|
+
for epoch in range(num_epochs):
|
|
350
|
+
# ... training loop ...
|
|
351
|
+
pass
|
|
352
|
+
|
|
353
|
+
# Get the final merged model
|
|
354
|
+
final_model = merged_model.merge_and_unload()
|
|
355
|
+
|
|
356
|
+
# Save or use the final model
|
|
357
|
+
torch.save(final_model.state_dict(), 'merged_model.pth')
|
|
358
|
+
output = final_model(new_input)
|
|
359
|
+
```
|
|
360
|
+
|
|
361
|
+
Warning:
|
|
362
|
+
This method modifies the pretrained_model's parameters in-place.
|
|
363
|
+
The original pretrained model parameters will be lost.
|
|
364
|
+
"""
|
|
225
365
|
self.merge_weights(task_vector_mask=task_vector_mask)
|
|
226
366
|
self.pretrained_model.load_state_dict(self._merged_state_dict)
|
|
227
367
|
return self.pretrained_model
|
|
228
368
|
|
|
229
369
|
def forward(self, *args, **kwargs):
|
|
370
|
+
"""
|
|
371
|
+
Forward pass through the dynamically merged model.
|
|
372
|
+
|
|
373
|
+
This method performs the forward pass by first ensuring the model parameters
|
|
374
|
+
are merged according to the current merge weights, then applying the merged
|
|
375
|
+
model to the input data.
|
|
376
|
+
|
|
377
|
+
The forward pass involves:
|
|
378
|
+
1. Compute the merged state dict if no cached copy exists yet
|
|
379
|
+
2. Apply the merged model to inputs using functional_call
|
|
380
|
+
3. Return the model outputs
|
|
381
|
+
|
|
382
|
+
Args:
|
|
383
|
+
*args: Positional arguments to pass to the underlying model.
|
|
384
|
+
**kwargs: Keyword arguments to pass to the underlying model.
|
|
385
|
+
|
|
386
|
+
Returns:
|
|
387
|
+
Any: The output of the merged model, typically torch.Tensor or tuple of tensors.
|
|
388
|
+
|
|
389
|
+
Example:
|
|
390
|
+
```python
|
|
391
|
+
# Single input
|
|
392
|
+
x = torch.randn(32, 784)
|
|
393
|
+
output = merged_model(x)
|
|
394
|
+
|
|
395
|
+
# Multiple inputs
|
|
396
|
+
x1, x2 = torch.randn(32, 784), torch.randn(32, 100)
|
|
397
|
+
output = merged_model(x1, x2)
|
|
398
|
+
|
|
399
|
+
# With keyword arguments
|
|
400
|
+
output = merged_model(input_ids=input_ids, attention_mask=attention_mask)
|
|
401
|
+
```
|
|
402
|
+
|
|
403
|
+
Note:
|
|
404
|
+
- The merged state dict is computed here only when no cached copy exists; call merge_weights() to refresh it after the merge weights are updated
|
|
405
|
+
- This allows for dynamic behavior during training as weights are updated
|
|
406
|
+
- The computation is efficient as merging only happens when needed
|
|
407
|
+
"""
|
|
230
408
|
if self._merged_state_dict is None:
|
|
231
409
|
self.merge_weights()
|
|
232
410
|
return self.forward_model(args=args, kwargs=kwargs)
|
|
@@ -1,9 +1,88 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Base Program Classes for FusionBench.
|
|
3
|
+
|
|
4
|
+
This module defines the foundational abstract base classes for FusionBench programs.
|
|
5
|
+
These programs serve as the main execution units that orchestrate model fusion
|
|
6
|
+
workflows, from loading configurations to executing fusion algorithms and
|
|
7
|
+
evaluating results.
|
|
8
|
+
|
|
9
|
+
The base classes provide a consistent interface for all FusionBench programs
|
|
10
|
+
while allowing for flexible implementations of different fusion workflows.
|
|
11
|
+
"""
|
|
12
|
+
|
|
1
13
|
from abc import abstractmethod
|
|
2
14
|
|
|
3
|
-
from fusion_bench.mixins import
|
|
15
|
+
from fusion_bench.mixins import BaseYAMLSerializable
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BaseHydraProgram(BaseYAMLSerializable):
|
|
19
|
+
"""
|
|
20
|
+
Abstract base class for all FusionBench programs that use Hydra configuration.
|
|
21
|
+
|
|
22
|
+
This class serves as the foundation for all FusionBench execution programs,
|
|
23
|
+
providing a standardized interface for configuration-driven model fusion
|
|
24
|
+
workflows. It combines the serialization capabilities of BaseYAMLSerializable
|
|
25
|
+
with the requirement for a main execution method.
|
|
26
|
+
|
|
27
|
+
The class is designed to work seamlessly with Hydra's configuration management
|
|
28
|
+
system, allowing programs to be instantiated and configured through YAML files.
|
|
29
|
+
This enables flexible, reproducible experiments with different fusion algorithms,
|
|
30
|
+
model pools, and evaluation tasks.
|
|
31
|
+
|
|
32
|
+
Key Features:
|
|
33
|
+
|
|
34
|
+
- Configuration-driven execution through Hydra integration
|
|
35
|
+
- YAML serialization support for experiment reproducibility
|
|
36
|
+
- Abstract interface ensuring consistent program structure
|
|
37
|
+
- Integration with FusionBench's modular architecture
|
|
4
38
|
|
|
39
|
+
Typical Usage:
|
|
40
|
+
Subclasses should implement the `run()` method to define their specific
|
|
41
|
+
fusion workflow. The program can then be executed through the FusionBench
|
|
42
|
+
CLI or instantiated directly from configuration files.
|
|
43
|
+
|
|
44
|
+
Example:
|
|
45
|
+
```python
|
|
46
|
+
class MyFusionProgram(BaseHydraProgram):
|
|
47
|
+
def __init__(self, method_config, modelpool_config, taskpool_config):
|
|
48
|
+
self.method_config = method_config
|
|
49
|
+
self.modelpool_config = modelpool_config
|
|
50
|
+
self.taskpool_config = taskpool_config
|
|
51
|
+
|
|
52
|
+
def run(self):
|
|
53
|
+
# Load components
|
|
54
|
+
algorithm = load_algorithm(self.method_config)
|
|
55
|
+
modelpool = load_modelpool(self.modelpool_config)
|
|
56
|
+
taskpool = load_taskpool(self.taskpool_config)
|
|
57
|
+
|
|
58
|
+
# Execute fusion
|
|
59
|
+
merged_model = algorithm.run(modelpool)
|
|
60
|
+
|
|
61
|
+
# Evaluate results
|
|
62
|
+
report = taskpool.evaluate(merged_model)
|
|
63
|
+
return report
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
Note:
|
|
67
|
+
This is an abstract base class and cannot be instantiated directly.
|
|
68
|
+
Subclasses must implement the `run()` method to provide concrete
|
|
69
|
+
functionality.
|
|
70
|
+
|
|
71
|
+
See Also:
|
|
72
|
+
|
|
73
|
+
- [FabricModelFusionProgram][fusion_bench.programs.FabricModelFusionProgram]: Lightning Fabric-based implementation
|
|
74
|
+
- [BaseYAMLSerializable][fusion_bench.mixins.BaseYAMLSerializable]: Parent class providing serialization
|
|
75
|
+
- FusionBench CLI documentation for program execution details
|
|
76
|
+
"""
|
|
5
77
|
|
|
6
|
-
class BaseHydraProgram(BaseYAMLSerializableModel):
|
|
7
78
|
@abstractmethod
|
|
8
79
|
def run(self):
|
|
80
|
+
"""
|
|
81
|
+
Execute the main program workflow.
|
|
82
|
+
|
|
83
|
+
This abstract method defines the primary entry point for program execution.
|
|
84
|
+
Subclasses must implement this method to define their specific fusion
|
|
85
|
+
workflow, including model loading, fusion algorithm execution, and
|
|
86
|
+
result evaluation.
|
|
87
|
+
"""
|
|
9
88
|
pass
|
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import json
|
|
2
2
|
import logging
|
|
3
3
|
import os
|
|
4
|
-
from typing import Callable, Dict, Iterable, Optional, Union # noqa: F401
|
|
4
|
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Union # noqa: F401
|
|
5
5
|
|
|
6
6
|
import lightning as L
|
|
7
7
|
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
|
@@ -18,8 +18,8 @@ from fusion_bench.taskpool import BaseTaskPool
|
|
|
18
18
|
from fusion_bench.utils import import_object, instantiate, timeit_context
|
|
19
19
|
from fusion_bench.utils.hydra_utils import get_hydra_output_dir
|
|
20
20
|
from fusion_bench.utils.json import print_json
|
|
21
|
-
from fusion_bench.utils.rich_utils import print_bordered, print_config_tree
|
|
22
21
|
from fusion_bench.utils.pylogger import getRankZeroLogger
|
|
22
|
+
from fusion_bench.utils.rich_utils import print_bordered, print_config_tree
|
|
23
23
|
|
|
24
24
|
log = getRankZeroLogger(__name__)
|
|
25
25
|
|
|
@@ -39,6 +39,7 @@ class FabricModelFusionProgram(
|
|
|
39
39
|
"_fabric": "fabric",
|
|
40
40
|
"fast_dev_run": "fast_dev_run",
|
|
41
41
|
"seed": "seed",
|
|
42
|
+
"path": "path",
|
|
42
43
|
}
|
|
43
44
|
|
|
44
45
|
def __init__(
|
|
@@ -56,6 +57,7 @@ class FabricModelFusionProgram(
|
|
|
56
57
|
fast_dev_run: bool = False,
|
|
57
58
|
seed: Optional[int] = None,
|
|
58
59
|
print_function_call: bool = True,
|
|
60
|
+
path: DictConfig = None,
|
|
59
61
|
**kwargs,
|
|
60
62
|
):
|
|
61
63
|
self._method = method
|
|
@@ -67,6 +69,7 @@ class FabricModelFusionProgram(
|
|
|
67
69
|
self.merged_model_save_kwargs = merged_model_save_kwargs
|
|
68
70
|
self.fast_dev_run = fast_dev_run
|
|
69
71
|
self.seed = seed
|
|
72
|
+
self.path = path
|
|
70
73
|
fusion_bench.utils.instantiate_utils.PRINT_FUNCTION_CALL = print_function_call
|
|
71
74
|
super().__init__(**kwargs)
|
|
72
75
|
|
|
@@ -164,9 +167,9 @@ class FabricModelFusionProgram(
|
|
|
164
167
|
self,
|
|
165
168
|
taskpool: BaseTaskPool,
|
|
166
169
|
merged_model: Union[nn.Module, Dict, Iterable],
|
|
167
|
-
*args,
|
|
168
|
-
**kwargs,
|
|
169
|
-
):
|
|
170
|
+
*args: Any,
|
|
171
|
+
**kwargs: Any,
|
|
172
|
+
) -> Union[Dict, List, Any]:
|
|
170
173
|
"""
|
|
171
174
|
Evaluates the merged model using the provided task pool.
|
|
172
175
|
|
|
@@ -243,7 +246,10 @@ class FabricModelFusionProgram(
|
|
|
243
246
|
compat_load_fn="fusion_bench.compat.taskpool.load_taskpool_from_config",
|
|
244
247
|
)
|
|
245
248
|
|
|
249
|
+
self.method.on_run_start()
|
|
246
250
|
merged_model = self.method.run(self.modelpool)
|
|
251
|
+
self.method.on_run_end()
|
|
252
|
+
|
|
247
253
|
if merged_model is None:
|
|
248
254
|
log.info(
|
|
249
255
|
"No merged model returned by the method. Skipping saving and evaluation."
|
|
@@ -261,8 +267,13 @@ class FabricModelFusionProgram(
|
|
|
261
267
|
if self.report_save_path is not None:
|
|
262
268
|
# save report (Dict) to a file
|
|
263
269
|
# if the directory of `save_report` does not exist, create it
|
|
264
|
-
if
|
|
265
|
-
|
|
270
|
+
if (
|
|
271
|
+
"{log_dir}" in self.report_save_path
|
|
272
|
+
and self.log_dir is not None
|
|
273
|
+
):
|
|
274
|
+
self.report_save_path = self.report_save_path.format(
|
|
275
|
+
log_dir=self.log_dir
|
|
276
|
+
)
|
|
266
277
|
os.makedirs(os.path.dirname(self.report_save_path), exist_ok=True)
|
|
267
278
|
json.dump(report, open(self.report_save_path, "w"))
|
|
268
279
|
else:
|
|
@@ -294,11 +305,16 @@ class FabricModelFusionProgram(
|
|
|
294
305
|
hydra_output_dir = None
|
|
295
306
|
|
|
296
307
|
if hydra_output_dir is not None:
|
|
308
|
+
if os.path.abspath(hydra_output_dir) == os.path.abspath(self.log_dir):
|
|
309
|
+
return
|
|
310
|
+
|
|
297
311
|
os.makedirs(self.log_dir, exist_ok=True)
|
|
298
312
|
try:
|
|
299
313
|
# if the system is windows, use the `mklink` command in "CMD" to create the symlink
|
|
300
314
|
if os.name == "nt":
|
|
301
|
-
os.system(
|
|
315
|
+
os.system(
|
|
316
|
+
f"mklink /J {os.path.abspath(os.path.join(self.log_dir, 'hydra_output_' + os.path.basename(hydra_output_dir)))} {os.path.abspath(hydra_output_dir)}"
|
|
317
|
+
)
|
|
302
318
|
else:
|
|
303
319
|
os.symlink(
|
|
304
320
|
hydra_output_dir,
|
fusion_bench/scripts/cli.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
#!/usr/bin/env python3
|
|
2
2
|
"""
|
|
3
|
-
This is the CLI script that is executed when the user runs the `
|
|
3
|
+
This is the CLI script that is executed when the user runs the `fusion_bench` command.
|
|
4
4
|
The script is responsible for parsing the command-line arguments, loading the configuration file, and running the fusion algorithm.
|
|
5
5
|
"""
|
|
6
6
|
|
|
@@ -14,17 +14,17 @@ from omegaconf import DictConfig, OmegaConf
|
|
|
14
14
|
|
|
15
15
|
from fusion_bench.programs import BaseHydraProgram
|
|
16
16
|
from fusion_bench.utils import instantiate
|
|
17
|
+
from fusion_bench.constants import PROJECT_ROOT_PATH
|
|
17
18
|
|
|
18
19
|
log = logging.getLogger(__name__)
|
|
19
20
|
|
|
20
21
|
|
|
21
22
|
def _get_default_config_path():
|
|
22
23
|
for config_dir in ["fusion_bench_config", "config"]:
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
return os.path.abspath(config_path)
|
|
24
|
+
for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
|
|
25
|
+
config_path = os.path.join(config_path_root, config_dir)
|
|
26
|
+
if os.path.exists(config_path) and os.path.isdir(config_path):
|
|
27
|
+
return os.path.abspath(config_path)
|
|
28
28
|
return None
|
|
29
29
|
|
|
30
30
|
|
|
@@ -1,14 +1,15 @@
|
|
|
1
1
|
from abc import abstractmethod
|
|
2
|
+
from typing import Any, Dict
|
|
2
3
|
|
|
3
|
-
from fusion_bench.mixins import BaseYAMLSerializableModel
|
|
4
|
+
from fusion_bench.mixins import BaseYAMLSerializable
|
|
4
5
|
|
|
5
6
|
|
|
6
|
-
class BaseTaskPool(BaseYAMLSerializableModel):
|
|
7
|
+
class BaseTaskPool(BaseYAMLSerializable):
|
|
7
8
|
_program = None
|
|
8
9
|
_config_key = "taskpool"
|
|
9
10
|
|
|
10
11
|
@abstractmethod
|
|
11
|
-
def evaluate(self, model, *args, **kwargs):
|
|
12
|
+
def evaluate(self, model: Any, *args: Any, **kwargs: Any) -> Dict[str, Any]:
|
|
12
13
|
"""
|
|
13
14
|
Evaluate the model on all tasks in the task pool, and return a report.
|
|
14
15
|
|