fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -1
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +16 -6
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
- fusion_bench/method/simple_average.py +16 -4
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +6 -6
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/misc.py +48 -2
- fusion_bench/utils/modelscope.py +265 -0
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +34 -27
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
```diff
--- a/fusion_bench/method/base_algorithm.py
+++ b/fusion_bench/method/base_algorithm.py
@@ -1,8 +1,44 @@
+"""
+Base algorithm classes for model fusion.
+
+This module provides the foundational abstract base class for implementing model fusion
+algorithms in the FusionBench framework. It defines the standard interface and lifecycle
+hooks that all fusion algorithms should follow.
+
+The main class `BaseAlgorithm` serves as a template for creating various model fusion
+strategies such as simple averaging, task arithmetic, weight interpolation, and more
+advanced techniques. It integrates with the YAML configuration system and provides
+hooks for setup and cleanup operations.
+
+Classes:
+    BaseAlgorithm: Abstract base class for all model fusion algorithms.
+    BaseModelFusionAlgorithm: Alias for BaseAlgorithm (backward compatibility).
+
+Example:
+    Implementing a custom fusion algorithm:
+
+    >>> from fusion_bench.method.base_algorithm import BaseAlgorithm
+    >>> from fusion_bench.modelpool import BaseModelPool
+    >>>
+    >>> class WeightedAverageAlgorithm(BaseAlgorithm):
+    ...     def __init__(self, weights=None, **kwargs):
+    ...         self.register_parameter_to_config("weights", "weights", weights or [])
+    ...         super().__init__(**kwargs)
+    ...
+    ...     def run(self, modelpool: BaseModelPool):
+    ...         models = list(modelpool)
+    ...         if len(self.weights) != len(models):
+    ...             raise ValueError("Number of weights must match number of models")
+    ...
+    ...         # Implement weighted averaging logic here
+    ...         return fused_model
+"""
+
 import logging
 from abc import abstractmethod
 from typing import Optional  # noqa: F401
 
-from fusion_bench.mixins import
+from fusion_bench.mixins import BaseYAMLSerializable
 from fusion_bench.modelpool import BaseModelPool
 
 __all__ = ["BaseAlgorithm", "BaseModelFusionAlgorithm"]
@@ -10,36 +46,183 @@ __all__ = ["BaseAlgorithm", "BaseModelFusionAlgorithm"]
 
 log = logging.getLogger(__name__)
 
 
-class BaseAlgorithm(
+class BaseAlgorithm(BaseYAMLSerializable):
     """
     Base class for model fusion algorithms.
 
-    This class provides a
-
+    This abstract class provides a standardized interface for implementing model fusion
+    algorithms. It inherits from BaseYAMLSerializable to support configuration loading
+    from YAML files.
+
+    The class follows a template method pattern where subclasses must implement the
+    core fusion logic in the `run` method, while optional lifecycle hooks allow for
+    setup and cleanup operations.
+
+    Attributes:
+        _program: Optional program reference for algorithm execution context.
+        _config_key (str): Configuration key used for YAML serialization, defaults to "method".
+
+    Examples:
+        Creating a simple averaging algorithm:
+
+        >>> class SimpleAverageAlgorithm(BaseAlgorithm):
+        ...     def run(self, modelpool: BaseModelPool):
+        ...         # Implementation of model averaging logic
+        ...         return averaged_model
+        ...
+        >>> algorithm = SimpleAverageAlgorithm()
+        >>> merged_model = algorithm.run(modelpool)
+
+        Loading algorithm from YAML configuration:
+
+        >>> algorithm = BaseAlgorithm.from_yaml("config.yaml")
+        >>> result = algorithm.run(modelpool)
+
+    Note:
+        Subclasses must implement the abstract `run` method to define the specific
+        fusion strategy (e.g., simple averaging, task arithmetic, etc.).
     """
 
     _program = None
     _config_key = "method"
 
+    def on_run_start(self):
+        """
+        Lifecycle hook called at the beginning of algorithm execution.
+
+        This method is invoked before the main `run` method executes, providing
+        an opportunity for subclasses to perform initialization tasks such as:
+
+        - Setting up logging or monitoring
+        - Initializing algorithm-specific state
+        - Validating prerequisites
+        - Preparing computational resources
+
+        The default implementation does nothing, allowing subclasses to override
+        as needed for their specific requirements.
+
+        Examples:
+            >>> class MyAlgorithm(BaseAlgorithm):
+            ...     def on_run_start(self):
+            ...         super().on_run_start()
+            ...         print("Starting model fusion...")
+            ...         self.start_time = time.time()
+        """
+        pass
+
+    def on_run_end(self):
+        """
+        Lifecycle hook called at the end of algorithm execution.
+
+        This method is invoked after the main `run` method completes, providing
+        an opportunity for subclasses to perform cleanup and finalization tasks such as:
+
+        - Logging execution statistics or results
+        - Cleaning up temporary resources
+        - Saving intermediate results or metrics
+        - Releasing computational resources
+
+        The method is called regardless of whether the `run` method succeeded or failed,
+        making it suitable for cleanup operations that should always occur.
+
+        The default implementation does nothing, allowing subclasses to override
+        as needed for their specific requirements.
+
+        Examples:
+            >>> class MyAlgorithm(BaseAlgorithm):
+            ...     def on_run_end(self):
+            ...         super().on_run_end()
+            ...         elapsed = time.time() - self.start_time
+            ...         print(f"Fusion completed in {elapsed:.2f}s")
+        """
+        pass
+
     @abstractmethod
     def run(self, modelpool: BaseModelPool):
         """
-
+        Execute the model fusion algorithm on the provided model pool.
+
+        This is the core method that must be implemented by all subclasses to define
+        their specific fusion strategy. The method takes a pool of models and produces
+        a fused result according to the algorithm's logic.
+
+        Args:
+            modelpool (BaseModelPool): A collection of models to be fused. The modelpool
+                provides access to individual models and their metadata, allowing the
+                algorithm to iterate over models, access their parameters, and perform
+                fusion operations.
+
+        Returns:
+            The type of return value depends on the specific algorithm implementation.
+            Common return types include:
+
+            - A single fused model (torch.nn.Module)
+            - A dictionary of fused models for multi-task scenarios
+            - Fusion results with additional metadata
+            - Custom data structures specific to the algorithm
 
-
+        Raises:
+            NotImplementedError: If called on the base class without implementation.
+            ValueError: If the modelpool is invalid or incompatible with the algorithm.
+            RuntimeError: If fusion fails due to model incompatibilities or other issues.
 
         Examples:
-
-            >>> modelpool = ModelPool()
-            >>> merged_model = algorithm.run(modelpool)
+            Simple averaging implementation:
 
-
-
+            >>> def run(self, modelpool: BaseModelPool):
+            ...     models = [model for model in modelpool]
+            ...     averaged_params = {}
+            ...     for name in models[0].state_dict():
+            ...         averaged_params[name] = torch.stack([
+            ...             model.state_dict()[name] for model in models
+            ...         ]).mean(dim=0)
+            ...     fused_model = copy.deepcopy(models[0])
+            ...     fused_model.load_state_dict(averaged_params)
+            ...     return fused_model
+
+            Task arithmetic implementation:
+
+            >>> def run(self, modelpool: BaseModelPool):
+            ...     pretrained = modelpool.get_model('pretrained')
+            ...     task_vectors = []
+            ...     for model_name in modelpool.model_names:
+            ...         if model_name != 'pretrained':
+            ...             task_vector = self.compute_task_vector(
+            ...                 modelpool.get_model(model_name), pretrained
+            ...             )
+            ...             task_vectors.append(task_vector)
+            ...     return self.merge_task_vectors(pretrained, task_vectors)
+
+        Note:
+            - The modelpool iteration order may affect results for non-commutative operations
+            - Ensure model compatibility (architecture, parameter shapes) before fusion
+            - Consider memory constraints when loading multiple large models
+            - Use appropriate device placement for GPU/CPU computation
         """
         pass
 
 
 BaseModelFusionAlgorithm = BaseAlgorithm
 """
-Alias for
+Alias for BaseAlgorithm class.
+
+This alias is provided for backward compatibility and semantic clarity.
+Some users may prefer the more explicit name 'BaseModelFusionAlgorithm'
+to emphasize that this class is specifically designed for model fusion
+tasks, while others may prefer the shorter 'BaseAlgorithm' name.
+
+Both names refer to the exact same class and can be used interchangeably.
+
+Examples:
+    Using the original name:
+    >>> class MyAlgorithm(BaseAlgorithm):
+    ...     def run(self, modelpool): pass
+
+    Using the alias:
+    >>> class MyAlgorithm(BaseModelFusionAlgorithm):
+    ...     def run(self, modelpool): pass
+
+Note:
+    The alias is maintained for compatibility but BaseAlgorithm is the
+    preferred name for new implementations.
 """
```
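The new lifecycle hooks pair naturally with the YAML-serialization support inherited from `BaseYAMLSerializable`. Below is a minimal sketch of the pattern the docstrings describe; the subclass, its `model_name` parameter, and the timing logic are illustrative and not part of the package, and `register_parameter_to_config` is used exactly as in the module's own docstring example:

```python
import time

from fusion_bench.method.base_algorithm import BaseAlgorithm
from fusion_bench.modelpool import BaseModelPool


class TimedSelectionAlgorithm(BaseAlgorithm):
    """Illustrative algorithm: loads one model from the pool and times the run."""

    def __init__(self, model_name: str = "pretrained", **kwargs):
        # Mirroring the docstring example above: record the hyperparameter so it
        # is emitted under the "method" key when serialized back to YAML.
        self.register_parameter_to_config("model_name", "model_name", model_name)
        super().__init__(**kwargs)

    def on_run_start(self):
        # Lifecycle hook: runs before `run`.
        self.start_time = time.time()

    def on_run_end(self):
        # Lifecycle hook: runs after `run` completes.
        print(f"fusion finished in {time.time() - self.start_time:.2f}s")

    def run(self, modelpool: BaseModelPool):
        return modelpool.load_model(self.model_name)
```

In a full run the surrounding fusion program is presumably what invokes `on_run_start` and `on_run_end` around `run`; in a standalone script they can simply be called manually in that order.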
```diff
--- /dev/null
+++ b/fusion_bench/method/bitdelta/bitdelta.py
@@ -0,0 +1,156 @@
+import logging
+import os
+
+import torch
+import torch.nn.functional as F
+from tqdm.auto import tqdm
+
+from fusion_bench import BaseAlgorithm, BaseModelPool
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
+from fusion_bench.modelpool import CausalLMPool
+
+from .bitdelta_utils.data import get_dataloader, get_dataset
+from .bitdelta_utils.diff import compress_diff, save_diff, save_full_model
+
+log = logging.getLogger(__name__)
+
+
+class BitDeltaAlgorithm(
+    BaseAlgorithm,
+    LightningFabricMixin,
+    SimpleProfilerMixin,
+):
+    _config_mapping = BaseAlgorithm._config_mapping | {
+        "save_dir": "save_dir",
+        "save_full_model": "save_full_model",
+        "lr": "lr",
+        "batch_size": "batch_size",
+        "num_steps": "num_steps",
+        "dataset_name": "dataset_name",
+        "subset": "subset",
+        "split": "split",
+        "max_length": "max_length",
+    }
+
+    def __init__(
+        self,
+        save_dir: str,
+        save_full_model: bool = False,
+        lr: float = 1e-4,
+        batch_size: int = 4,
+        num_steps: int = 100,
+        dataset_name: str = "c4",
+        subset: str = "en",
+        split: str = "train",
+        max_length: int = 128,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.save_dir = save_dir
+        self.save_full_model = save_full_model
+        self.lr = lr
+        self.batch_size = batch_size
+        self.num_steps = num_steps
+        self.dataset_name = dataset_name
+        self.subset = subset
+        self.split = split
+        self.max_length = max_length
+
+    def run(self, modelpool: CausalLMPool):
+        if self.save_dir is None:
+            log.info(
+                f"save_dir not set, using log_dir instead. log_dir: {self.log_dir}"
+            )
+            self.save_dir = self.log_dir
+
+        with self.profile("model loading"):
+            tokenizer = modelpool.load_tokenizer()
+            base_model = modelpool.load_pretrained_model()
+            finetuned_model = modelpool.load_model(modelpool.model_names[0])
+            finetuned_compressed_model = modelpool.load_model(modelpool.model_names[0])
+
+        with self.profile("model compression"):
+            print(f"compressing diff...")
+            compress_diff(base_model, finetuned_model, finetuned_compressed_model)
+
+        # save untrained delta
+        save_diff(
+            finetuned_compressed_model, os.path.join(self.save_dir, "diff_untrained.pt")
+        )
+
+        optimizer = torch.optim.AdamW(
+            finetuned_compressed_model.parameters(), lr=self.lr
+        )
+        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, self.num_steps
+        )
+
+        train_num_samples = self.batch_size * self.num_steps
+        train_dataset = get_dataset(
+            self.dataset_name,
+            self.subset,
+            "train",
+            size=train_num_samples,
+        )
+        train_dataloader = get_dataloader(
+            train_dataset,
+            tokenizer,
+            self.batch_size,
+            num_workers=4,
+            max_length=self.max_length,
+        )
+
+        bar = tqdm(train_dataloader)
+
+        train_loss_list = []
+
+        # Train loop
+        for step, batch in enumerate(bar):
+            batch1 = {k: v.to(finetuned_model.device) for k, v in batch.items()}
+            with torch.inference_mode():
+                finetuned_outputs = finetuned_model(**batch1)
+
+            batch2 = {
+                k: v.to(finetuned_compressed_model.device) for k, v in batch.items()
+            }
+            finetuned_compressed_outputs = finetuned_compressed_model(**batch2)
+
+            loss = F.mse_loss(
+                finetuned_outputs.logits.clone().to(
+                    finetuned_compressed_outputs.logits.device
+                ),
+                finetuned_compressed_outputs.logits,
+            )
+
+            train_loss_list.append(loss.item())
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+            scheduler.step()
+
+            bar.set_description(f"train loss: {loss.item()}")
+
+        # save trained delta
+        save_diff(finetuned_compressed_model, os.path.join(self.save_dir, "diff.pt"))
+
+        if self.save_full_model:
+            print("saving uncalibrated model")
+            save_full_model(
+                base_model,
+                tokenizer,
+                os.path.join(self.save_dir, "diff_untrained.pt"),
+                os.path.join(self.save_dir, "uncalibrated_model"),
+                device="cpu",
+            )
+            print("saving calibrated model")
+            save_full_model(
+                base_model,
+                tokenizer,
+                os.path.join(self.save_dir, "diff.pt"),
+                os.path.join(self.save_dir, "calibrated_model"),
+                device="cpu",
+            )
+
+        del base_model, finetuned_model, finetuned_compressed_model
+        torch.cuda.empty_cache()
```
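For reference, a direct (non-Hydra) invocation of the new algorithm might look like the sketch below. Only the `BitDeltaAlgorithm` arguments are taken from the constructor above; the import path is inferred from the package layout, and the construction of the `CausalLMPool` (normally driven by a modelpool YAML under `fusion_bench_config/modelpool/CausalLMPool/`) is deliberately left to the caller:

```python
from fusion_bench.method.bitdelta import BitDeltaAlgorithm
from fusion_bench.modelpool import CausalLMPool


def run_bitdelta(modelpool: CausalLMPool) -> None:
    """Run BitDelta on a pool holding the base model plus one fine-tuned model.

    The algorithm only reads modelpool.model_names[0], so a single fine-tuned
    model alongside the pretrained base model is sufficient.
    """
    algorithm = BitDeltaAlgorithm(
        save_dir="outputs/bitdelta",  # where diff_untrained.pt and diff.pt land
        save_full_model=True,         # also save uncalibrated/calibrated full models
        lr=1e-4,                      # AdamW learning rate for calibration
        batch_size=4,
        num_steps=100,                # optimizer steps; also the cosine LR horizon
        dataset_name="c4",            # calibration corpus
        subset="en",
        split="train",
        max_length=128,               # tokenized sequence length per sample
    )
    algorithm.run(modelpool)  # compresses the delta, then distills for num_steps
```

Note that `run` calibrates the compressed model by minimizing the MSE between its logits and those of the original fine-tuned model over the calibration batches, so the saved `diff.pt` holds the calibrated delta rather than the raw output of `compress_diff`.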