fusion-bench 0.2.19__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -1
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +16 -6
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +3 -0
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +199 -0
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +383 -0
- fusion_bench/method/simple_average.py +16 -4
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +43 -12
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +6 -6
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/misc.py +48 -2
- fusion_bench/utils/modelscope.py +265 -0
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +34 -27
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +31 -24
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +189 -153
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/regmean/clip_regmean.yaml +1 -1
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +11 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8.yaml +34 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +14 -17
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_model_only.yaml +14 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL10.yaml +39 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL12.yaml +49 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14.yaml +55 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL14_model_only.yaml +21 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL16.yaml +61 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL18.yaml +67 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20.yaml +73 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TALL20_model_only.yaml +26 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +7 -5
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +6 -10
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_cars.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_and_dtd.yaml +6 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_sun397_cars_and_dtd.yaml +7 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +8 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +4 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +32 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +14 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20.yaml +73 -8
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TALL20_model_only.yaml +27 -7
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +6 -10
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.19.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
|
2
|
+
|
|
3
|
+
from .configuration_smile_llama import SmileLlamaConfig
|
|
4
|
+
from .modeling_smile_llama import SmileLlamaForCausalLM, SmileLlamaModel
|
|
5
|
+
|
|
6
|
+
AutoConfig.register("smile_llama", SmileLlamaConfig)
|
|
7
|
+
AutoModel.register(SmileLlamaConfig, SmileLlamaModel)
|
|
8
|
+
AutoModelForCausalLM.register(SmileLlamaConfig, SmileLlamaForCausalLM)
|
|
@@ -1,48 +1,6 @@
|
|
|
1
|
-
|
|
2
|
-
from
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
OptionalDependencyNotAvailable,
|
|
6
|
-
_LazyModule,
|
|
7
|
-
is_flax_available,
|
|
8
|
-
is_tf_available,
|
|
9
|
-
is_torch_available,
|
|
1
|
+
from .configuration_smile_mistral import SmileMistralConfig
|
|
2
|
+
from .modeling_smile_mistral import (
|
|
3
|
+
SmileMistralForCausalLM,
|
|
4
|
+
SmileMistralModel,
|
|
10
5
|
)
|
|
11
|
-
|
|
12
|
-
_import_structure = {
|
|
13
|
-
"configuration_smile_mistral": ["SmileMistralConfig"],
|
|
14
|
-
}
|
|
15
|
-
|
|
16
|
-
try:
|
|
17
|
-
if not is_torch_available():
|
|
18
|
-
raise OptionalDependencyNotAvailable()
|
|
19
|
-
except OptionalDependencyNotAvailable:
|
|
20
|
-
pass
|
|
21
|
-
else:
|
|
22
|
-
_import_structure["modeling_smile_mistral"] = [
|
|
23
|
-
"SmileMistralForCausalLM",
|
|
24
|
-
"SmileMistralModel",
|
|
25
|
-
"SmileMistralPreTrainedModel",
|
|
26
|
-
]
|
|
27
|
-
|
|
28
|
-
if TYPE_CHECKING:
|
|
29
|
-
from .configuration_smile_mistral import SmileMistralConfig
|
|
30
|
-
|
|
31
|
-
try:
|
|
32
|
-
if not is_torch_available():
|
|
33
|
-
raise OptionalDependencyNotAvailable()
|
|
34
|
-
except OptionalDependencyNotAvailable:
|
|
35
|
-
pass
|
|
36
|
-
else:
|
|
37
|
-
from .modeling_smile_mistral import (
|
|
38
|
-
SmileMistralForCausalLM,
|
|
39
|
-
SmileMistralModel,
|
|
40
|
-
SmileMistralPreTrainedModel,
|
|
41
|
-
)
|
|
42
|
-
|
|
43
|
-
else:
|
|
44
|
-
import sys
|
|
45
|
-
|
|
46
|
-
sys.modules[__name__] = _LazyModule(
|
|
47
|
-
__name__, globals()["__file__"], _import_structure, module_spec=__spec__
|
|
48
|
-
)
|
|
6
|
+
from . import register
|
|
@@ -24,7 +24,6 @@ from transformers.modeling_outputs import (
|
|
|
24
24
|
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
25
25
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
26
26
|
from transformers.models.qwen2.modeling_qwen2 import (
|
|
27
|
-
QWEN2_INPUTS_DOCSTRING,
|
|
28
27
|
Qwen2RMSNorm,
|
|
29
28
|
Qwen2RotaryEmbedding,
|
|
30
29
|
apply_rotary_pos_emb,
|
|
@@ -314,7 +313,6 @@ class SmileQwen2Model(SmileQwen2PreTrainedModel):
|
|
|
314
313
|
self.embed_tokens = value
|
|
315
314
|
|
|
316
315
|
@can_return_tuple
|
|
317
|
-
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
|
|
318
316
|
def forward(
|
|
319
317
|
self,
|
|
320
318
|
input_ids: Optional[torch.LongTensor] = None,
|
|
@@ -646,7 +644,6 @@ class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
|
|
|
646
644
|
|
|
647
645
|
@can_return_tuple
|
|
648
646
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
|
649
|
-
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
|
|
650
647
|
@replace_return_docstrings(
|
|
651
648
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
|
652
649
|
)
|
|
@@ -752,7 +749,9 @@ class SmileQwen2ForSequenceClassification(SmileQwen2PreTrainedModel):
|
|
|
752
749
|
def __init__(self, config):
|
|
753
750
|
super().__init__(config)
|
|
754
751
|
self.num_labels = config.num_labels
|
|
755
|
-
self.model = SmileQwen2Model(
|
|
752
|
+
self.model = SmileQwen2Model(
|
|
753
|
+
config
|
|
754
|
+
) # * replace Qwen2Model with SmileQwen2Model
|
|
756
755
|
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
|
757
756
|
|
|
758
757
|
# Initialize weights and apply final processing
|
|
@@ -765,7 +764,6 @@ class SmileQwen2ForSequenceClassification(SmileQwen2PreTrainedModel):
|
|
|
765
764
|
self.model.embed_tokens = value
|
|
766
765
|
|
|
767
766
|
@can_return_tuple
|
|
768
|
-
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
|
|
769
767
|
def forward(
|
|
770
768
|
self,
|
|
771
769
|
input_ids: Optional[torch.LongTensor] = None,
|
|
@@ -852,7 +850,9 @@ class SmileQwen2ForQuestionAnswering(SmileQwen2PreTrainedModel):
|
|
|
852
850
|
|
|
853
851
|
def __init__(self, config):
|
|
854
852
|
super().__init__(config)
|
|
855
|
-
self.transformer = SmileQwen2Model(
|
|
853
|
+
self.transformer = SmileQwen2Model(
|
|
854
|
+
config
|
|
855
|
+
) # * replace Qwen2Model with SmileQwen2Model
|
|
856
856
|
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
|
857
857
|
|
|
858
858
|
# Initialize weights and apply final processing
|
|
@@ -865,7 +865,6 @@ class SmileQwen2ForQuestionAnswering(SmileQwen2PreTrainedModel):
|
|
|
865
865
|
self.transformer.embed_tokens = value
|
|
866
866
|
|
|
867
867
|
@can_return_tuple
|
|
868
|
-
@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
|
|
869
868
|
def forward(
|
|
870
869
|
self,
|
|
871
870
|
input_ids: Optional[torch.LongTensor] = None,
|
|
@@ -1,10 +1,7 @@
|
|
|
1
1
|
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
|
2
2
|
|
|
3
3
|
from .configuration_smile_qwen2 import SmileQwen2Config
|
|
4
|
-
from .modeling_smile_qwen2 import
|
|
5
|
-
SmileQwen2ForCausalLM,
|
|
6
|
-
SmileQwen2Model,
|
|
7
|
-
)
|
|
4
|
+
from .modeling_smile_qwen2 import SmileQwen2ForCausalLM, SmileQwen2Model
|
|
8
5
|
|
|
9
6
|
AutoConfig.register("smile_qwen2", SmileQwen2Config)
|
|
10
7
|
AutoModel.register(SmileQwen2Config, SmileQwen2Model)
|
|
@@ -11,6 +11,7 @@ from torch.func import functional_call
|
|
|
11
11
|
from torch.nn import functional as F
|
|
12
12
|
from tqdm.auto import tqdm
|
|
13
13
|
|
|
14
|
+
from fusion_bench.models.utils import del_attr, get_attr, set_attr
|
|
14
15
|
from fusion_bench.utils.state_dict_arithmetic import (
|
|
15
16
|
state_dict_sub,
|
|
16
17
|
state_dict_weighted_sum,
|
|
@@ -20,59 +21,6 @@ from fusion_bench.utils.type import StateDictType
|
|
|
20
21
|
log = logging.getLogger(__name__)
|
|
21
22
|
|
|
22
23
|
|
|
23
|
-
def join_list(list_of_list: List[List]):
|
|
24
|
-
ans = []
|
|
25
|
-
for l in list_of_list:
|
|
26
|
-
ans.extend(l)
|
|
27
|
-
return ans
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def del_attr(obj, names: List[str]):
|
|
31
|
-
"""
|
|
32
|
-
Deletes an attribute from an object recursively.
|
|
33
|
-
|
|
34
|
-
Args:
|
|
35
|
-
obj (object): Object to delete attribute from.
|
|
36
|
-
names (list): List of attribute names to delete recursively.
|
|
37
|
-
"""
|
|
38
|
-
if len(names) == 1:
|
|
39
|
-
delattr(obj, names[0])
|
|
40
|
-
else:
|
|
41
|
-
del_attr(getattr(obj, names[0]), names[1:])
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
def set_attr(obj, names: List[str], val):
|
|
45
|
-
"""
|
|
46
|
-
Sets an attribute of an object recursively.
|
|
47
|
-
|
|
48
|
-
Args:
|
|
49
|
-
obj (object): Object to set attribute of.
|
|
50
|
-
names (list): List of attribute names to set recursively.
|
|
51
|
-
val (object): Value to set the attribute to.
|
|
52
|
-
"""
|
|
53
|
-
if len(names) == 1:
|
|
54
|
-
setattr(obj, names[0], val)
|
|
55
|
-
else:
|
|
56
|
-
set_attr(getattr(obj, names[0]), names[1:], val)
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
def get_attr(obj, names: List[str]):
|
|
60
|
-
"""
|
|
61
|
-
Gets an attribute of an object recursively.
|
|
62
|
-
|
|
63
|
-
Args:
|
|
64
|
-
obj (object): Object to get attribute of.
|
|
65
|
-
names (list): List of attribute names to get recursively.
|
|
66
|
-
|
|
67
|
-
Returns:
|
|
68
|
-
object: The attribute of the object.
|
|
69
|
-
"""
|
|
70
|
-
if len(names) == 1:
|
|
71
|
-
return getattr(obj, names[0])
|
|
72
|
-
else:
|
|
73
|
-
return get_attr(getattr(obj, names[0]), names[1:])
|
|
74
|
-
|
|
75
|
-
|
|
76
24
|
class Depth_0_Gate(nn.Module):
|
|
77
25
|
def __init__(self, num_experts: int):
|
|
78
26
|
super().__init__()
|
fusion_bench/models/utils.py
CHANGED
|
@@ -3,6 +3,8 @@ from typing import List
|
|
|
3
3
|
import torch
|
|
4
4
|
from torch import nn
|
|
5
5
|
|
|
6
|
+
from fusion_bench.utils.type import StateDictType
|
|
7
|
+
|
|
6
8
|
|
|
7
9
|
def del_attr(obj, names: List[str]):
|
|
8
10
|
"""
|
|
@@ -50,6 +52,30 @@ def get_attr(obj, names: List[str]):
|
|
|
50
52
|
return get_attr(getattr(obj, names[0]), names[1:])
|
|
51
53
|
|
|
52
54
|
|
|
55
|
+
def check_parameterNamesMatch(checkpoints: List[StateDictType]) -> None:
|
|
56
|
+
"""
|
|
57
|
+
Checks that the parameter names of the given checkpoints match.
|
|
58
|
+
|
|
59
|
+
Args:
|
|
60
|
+
checkpoints (List[Dict[str, float]]): A list of checkpoints, where each checkpoint is a dictionary of parameter names and their corresponding values.
|
|
61
|
+
|
|
62
|
+
Raises:
|
|
63
|
+
ValueError: If the number of checkpoints is less than 2 or if the parameter names of any two checkpoints differ.
|
|
64
|
+
|
|
65
|
+
"""
|
|
66
|
+
parameter_names = set(checkpoints[0].keys())
|
|
67
|
+
|
|
68
|
+
if len(checkpoints) >= 2:
|
|
69
|
+
# raise ValueError("Number of models is less than 2.")
|
|
70
|
+
for checkpoint in checkpoints[1:]:
|
|
71
|
+
current_parameterNames = set(checkpoint.keys())
|
|
72
|
+
if current_parameterNames != parameter_names:
|
|
73
|
+
raise ValueError(
|
|
74
|
+
"Differing parameter names in models. "
|
|
75
|
+
f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
|
|
53
79
|
def find_layers_with_type(
|
|
54
80
|
module: nn.Module,
|
|
55
81
|
layer_types=[nn.Linear],
|
fusion_bench/models/we_moe.py
CHANGED
|
@@ -8,64 +8,12 @@ from torch import Tensor, nn
|
|
|
8
8
|
from torch.func import functional_call
|
|
9
9
|
from torch.nn import functional as F
|
|
10
10
|
|
|
11
|
+
from fusion_bench.models.utils import del_attr, get_attr, set_attr
|
|
11
12
|
from fusion_bench.utils.type import StateDictType
|
|
12
13
|
|
|
13
14
|
log = logging.getLogger(__name__)
|
|
14
15
|
|
|
15
16
|
|
|
16
|
-
def join_list(list_of_list: List[List]):
|
|
17
|
-
ans = []
|
|
18
|
-
for l in list_of_list:
|
|
19
|
-
ans.extend(l)
|
|
20
|
-
return ans
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def del_attr(obj, names: List[str]):
|
|
24
|
-
"""
|
|
25
|
-
Deletes an attribute from an object recursively.
|
|
26
|
-
|
|
27
|
-
Args:
|
|
28
|
-
obj (object): Object to delete attribute from.
|
|
29
|
-
names (list): List of attribute names to delete recursively.
|
|
30
|
-
"""
|
|
31
|
-
if len(names) == 1:
|
|
32
|
-
delattr(obj, names[0])
|
|
33
|
-
else:
|
|
34
|
-
del_attr(getattr(obj, names[0]), names[1:])
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
def set_attr(obj, names: List[str], val):
|
|
38
|
-
"""
|
|
39
|
-
Sets an attribute of an object recursively.
|
|
40
|
-
|
|
41
|
-
Args:
|
|
42
|
-
obj (object): Object to set attribute of.
|
|
43
|
-
names (list): List of attribute names to set recursively.
|
|
44
|
-
val (object): Value to set the attribute to.
|
|
45
|
-
"""
|
|
46
|
-
if len(names) == 1:
|
|
47
|
-
setattr(obj, names[0], val)
|
|
48
|
-
else:
|
|
49
|
-
set_attr(getattr(obj, names[0]), names[1:], val)
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
def get_attr(obj, names: List[str]):
|
|
53
|
-
"""
|
|
54
|
-
Gets an attribute of an object recursively.
|
|
55
|
-
|
|
56
|
-
Args:
|
|
57
|
-
obj (object): Object to get attribute of.
|
|
58
|
-
names (list): List of attribute names to get recursively.
|
|
59
|
-
|
|
60
|
-
Returns:
|
|
61
|
-
object: The attribute of the object.
|
|
62
|
-
"""
|
|
63
|
-
if len(names) == 1:
|
|
64
|
-
return getattr(obj, names[0])
|
|
65
|
-
else:
|
|
66
|
-
return get_attr(getattr(obj, names[0]), names[1:])
|
|
67
|
-
|
|
68
|
-
|
|
69
17
|
class Depth_0_Gate(nn.Module):
|
|
70
18
|
def __init__(self, num_experts: int):
|
|
71
19
|
super().__init__()
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Callable, Dict, List, cast
|
|
1
|
+
from typing import Any, Callable, Dict, List, Union, cast
|
|
2
2
|
|
|
3
3
|
import numpy as np
|
|
4
4
|
import torch
|
|
@@ -6,7 +6,9 @@ from omegaconf import ListConfig
|
|
|
6
6
|
from torch import Tensor, nn
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
def aggregate_tensors(
|
|
9
|
+
def aggregate_tensors(
|
|
10
|
+
outputs: List[Any], aggregate_fn: Callable
|
|
11
|
+
) -> Union[Tensor, Dict, List, None]:
|
|
10
12
|
"""
|
|
11
13
|
Aggregates a list of outputs using the provided aggregation function.
|
|
12
14
|
|
|
@@ -84,7 +86,7 @@ class EnsembleModule(nn.Module):
|
|
|
84
86
|
"""
|
|
85
87
|
return torch.stack(outputs).mean(dim=0)
|
|
86
88
|
|
|
87
|
-
def forward(self, *args, **kwargs):
|
|
89
|
+
def forward(self, *args: Any, **kwargs: Any) -> Any:
|
|
88
90
|
"""
|
|
89
91
|
Performs a forward pass by averaging the outputs of the models.
|
|
90
92
|
|
|
@@ -150,7 +152,7 @@ class WeightedEnsembleModule(nn.Module):
|
|
|
150
152
|
weights = cast(Tensor, self.weights).view(-1, *([1] * outputs[0].dim()))
|
|
151
153
|
return (torch.stack(outputs) * weights).sum(dim=0)
|
|
152
154
|
|
|
153
|
-
def forward(self, *args, **kwargs):
|
|
155
|
+
def forward(self, *args: Any, **kwargs: Any) -> Any:
|
|
154
156
|
"""
|
|
155
157
|
Performs a forward pass by computing the weighted average of the models' outputs.
|
|
156
158
|
|
|
@@ -49,7 +49,7 @@ def get_layer_wise_weights(
|
|
|
49
49
|
return torch.full((num_models, num_layers), init_values, dtype=dtype)
|
|
50
50
|
|
|
51
51
|
|
|
52
|
-
def _fuse_weights(layer_wise_weight: Tensor, tensors: List[Tensor]):
|
|
52
|
+
def _fuse_weights(layer_wise_weight: Tensor, tensors: List[Tensor]) -> Tensor:
|
|
53
53
|
"""
|
|
54
54
|
Fuse the layer-wise weights with the given state dictionaries.
|
|
55
55
|
|