fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +22 -2
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +6 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +24 -5
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +5 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +17 -13
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +12 -16
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +15 -45
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +275 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +7 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +160 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +46 -61
- fusion_bench/scripts/cli.py +38 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +7 -1
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +38 -3
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0
fusion_bench/models/modeling_smile_mistral/__init__.py
CHANGED
```diff
@@ -1,48 +1,6 @@
-from typing import TYPE_CHECKING
-
-from transformers.utils import (
-    OptionalDependencyNotAvailable,
-    _LazyModule,
-    is_flax_available,
-    is_tf_available,
-    is_torch_available,
-)
+from . import register
+from .configuration_smile_mistral import SmileMistralConfig
+from .modeling_smile_mistral import (
+    SmileMistralForCausalLM,
+    SmileMistralModel,
 )
-
-_import_structure = {
-    "configuration_smile_mistral": ["SmileMistralConfig"],
-}
-
-try:
-    if not is_torch_available():
-        raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-    pass
-else:
-    _import_structure["modeling_smile_mistral"] = [
-        "SmileMistralForCausalLM",
-        "SmileMistralModel",
-        "SmileMistralPreTrainedModel",
-    ]
-
-if TYPE_CHECKING:
-    from .configuration_smile_mistral import SmileMistralConfig
-
-    try:
-        if not is_torch_available():
-            raise OptionalDependencyNotAvailable()
-    except OptionalDependencyNotAvailable:
-        pass
-    else:
-        from .modeling_smile_mistral import (
-            SmileMistralForCausalLM,
-            SmileMistralModel,
-            SmileMistralPreTrainedModel,
-        )
-
-else:
-    import sys
-
-    sys.modules[__name__] = _LazyModule(
-        __name__, globals()["__file__"], _import_structure, module_spec=__spec__
-    )
```
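With the lazy-import scaffolding removed, the submodule's public names now resolve eagerly at import time. A minimal sketch of the resulting import surface (this requires `torch` and `transformers` to be installed, since the `is_torch_available()` guard is gone):

```python
# Eager imports: these names are bound when the package is imported, and
# `from . import register` runs the module's registration side effects.
from fusion_bench.models.modeling_smile_mistral import (
    SmileMistralConfig,
    SmileMistralForCausalLM,
    SmileMistralModel,
)
```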
fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py
CHANGED
```diff
@@ -24,7 +24,6 @@ from transformers.modeling_outputs import (
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.models.qwen2.modeling_qwen2 import (
-    QWEN2_INPUTS_DOCSTRING,
     Qwen2RMSNorm,
     Qwen2RotaryEmbedding,
     apply_rotary_pos_emb,
@@ -32,7 +31,6 @@ from transformers.models.qwen2.modeling_qwen2 import (
 )
 from transformers.processing_utils import Unpack
 from transformers.utils import (
-    LossKwargs,
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -314,7 +312,6 @@ class SmileQwen2Model(SmileQwen2PreTrainedModel):
         self.embed_tokens = value
 
     @can_return_tuple
-    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -609,9 +606,6 @@ class SmileQwen2Model(SmileQwen2PreTrainedModel):
         return causal_mask
 
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
-
-
 class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     _tp_plan = {"lm_head": "colwise_rep"}
@@ -646,7 +640,6 @@ class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
 
     @can_return_tuple
     @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
-    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
     @replace_return_docstrings(
         output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
     )
@@ -663,7 +656,7 @@ class SmileQwen2ForCausalLM(SmileQwen2PreTrainedModel, GenerationMixin):
         output_hidden_states: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
-        **kwargs
+        **kwargs,
     ) -> CausalLMOutputWithPast:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -752,7 +745,9 @@ class SmileQwen2ForSequenceClassification(SmileQwen2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
-        self.model = SmileQwen2Model(config)  # * replace Qwen2Model with SmileQwen2Model
+        self.model = SmileQwen2Model(
+            config
+        )  # * replace Qwen2Model with SmileQwen2Model
         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
         # Initialize weights and apply final processing
@@ -765,7 +760,6 @@ class SmileQwen2ForSequenceClassification(SmileQwen2PreTrainedModel):
         self.model.embed_tokens = value
 
     @can_return_tuple
-    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -852,7 +846,9 @@ class SmileQwen2ForQuestionAnswering(SmileQwen2PreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        self.transformer = SmileQwen2Model(config)  # * replace Qwen2Model with SmileQwen2Model
+        self.transformer = SmileQwen2Model(
+            config
+        )  # * replace Qwen2Model with SmileQwen2Model
         self.qa_outputs = nn.Linear(config.hidden_size, 2)
 
         # Initialize weights and apply final processing
@@ -865,7 +861,6 @@ class SmileQwen2ForQuestionAnswering(SmileQwen2PreTrainedModel):
         self.transformer.embed_tokens = value
 
     @can_return_tuple
-    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
```
fusion_bench/models/modeling_smile_qwen2/register.py
CHANGED
```diff
@@ -1,10 +1,7 @@
 from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
 
 from .configuration_smile_qwen2 import SmileQwen2Config
-from .modeling_smile_qwen2 import (
-    SmileQwen2ForCausalLM,
-    SmileQwen2Model,
-)
+from .modeling_smile_qwen2 import SmileQwen2ForCausalLM, SmileQwen2Model
 
 AutoConfig.register("smile_qwen2", SmileQwen2Config)
 AutoModel.register(SmileQwen2Config, SmileQwen2Model)
```
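`register.py` hooks the custom architecture into the transformers Auto-class machinery. A hedged sketch of the effect, using only standard transformers APIs (`AutoConfig.for_model`, `AutoModel.from_config`); whether `SmileQwen2Config` is constructible from defaults is an assumption here:

```python
from transformers import AutoConfig, AutoModel

# Importing the module runs the register() calls shown in the hunk above.
import fusion_bench.models.modeling_smile_qwen2.register  # noqa: F401

# "smile_qwen2" now resolves through the Auto mappings.
config = AutoConfig.for_model("smile_qwen2")  # assumes default config values suffice
model = AutoModel.from_config(config)         # instantiates SmileQwen2Model
```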
fusion_bench/models/sparse_we_moe.py
CHANGED
```diff
@@ -11,6 +11,7 @@ from torch.func import functional_call
 from torch.nn import functional as F
 from tqdm.auto import tqdm
 
+from fusion_bench.models.utils import del_attr, get_attr, set_attr
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_sub,
     state_dict_weighted_sum,
@@ -20,59 +21,6 @@ from fusion_bench.utils.type import StateDictType
 log = logging.getLogger(__name__)
 
 
-def join_list(list_of_list: List[List]):
-    ans = []
-    for l in list_of_list:
-        ans.extend(l)
-    return ans
-
-
-def del_attr(obj, names: List[str]):
-    """
-    Deletes an attribute from an object recursively.
-
-    Args:
-        obj (object): Object to delete attribute from.
-        names (list): List of attribute names to delete recursively.
-    """
-    if len(names) == 1:
-        delattr(obj, names[0])
-    else:
-        del_attr(getattr(obj, names[0]), names[1:])
-
-
-def set_attr(obj, names: List[str], val):
-    """
-    Sets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to set attribute of.
-        names (list): List of attribute names to set recursively.
-        val (object): Value to set the attribute to.
-    """
-    if len(names) == 1:
-        setattr(obj, names[0], val)
-    else:
-        set_attr(getattr(obj, names[0]), names[1:], val)
-
-
-def get_attr(obj, names: List[str]):
-    """
-    Gets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to get attribute of.
-        names (list): List of attribute names to get recursively.
-
-    Returns:
-        object: The attribute of the object.
-    """
-    if len(names) == 1:
-        return getattr(obj, names[0])
-    else:
-        return get_attr(getattr(obj, names[0]), names[1:])
-
-
 class Depth_0_Gate(nn.Module):
     def __init__(self, num_experts: int):
         super().__init__()
```
fusion_bench/models/utils.py
CHANGED
```diff
@@ -3,6 +3,8 @@ from typing import List
 import torch
 from torch import nn
 
+from fusion_bench.utils.type import StateDictType
+
 
 def del_attr(obj, names: List[str]):
     """
@@ -50,6 +52,30 @@ def get_attr(obj, names: List[str]):
         return get_attr(getattr(obj, names[0]), names[1:])
 
 
+def check_parameterNamesMatch(checkpoints: List[StateDictType]) -> None:
+    """
+    Checks that the parameter names of the given checkpoints match.
+
+    Args:
+        checkpoints (List[Dict[str, float]]): A list of checkpoints, where each checkpoint is a dictionary of parameter names and their corresponding values.
+
+    Raises:
+        ValueError: If the number of checkpoints is less than 2 or if the parameter names of any two checkpoints differ.
+
+    """
+    parameter_names = set(checkpoints[0].keys())
+
+    if len(checkpoints) >= 2:
+        # raise ValueError("Number of models is less than 2.")
+        for checkpoint in checkpoints[1:]:
+            current_parameterNames = set(checkpoint.keys())
+            if current_parameterNames != parameter_names:
+                raise ValueError(
+                    "Differing parameter names in models. "
+                    f"The different parameters are {parameter_names.symmetric_difference(current_parameterNames)}"
+                )
+
+
 def find_layers_with_type(
     module: nn.Module,
     layer_types=[nn.Linear],
```
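A short usage sketch for the helpers now centralized in `fusion_bench.models.utils` (the attribute helpers expect a dotted parameter path already split into a list of names):

```python
import torch
from torch import nn

from fusion_bench.models.utils import check_parameterNamesMatch, get_attr, set_attr

model = nn.Sequential(nn.Linear(4, 4))
# Recursive getattr/setattr over the module tree: "0.weight" -> ["0", "weight"].
weight = get_attr(model, "0.weight".split("."))
set_attr(model, "0.weight".split("."), nn.Parameter(torch.zeros(4, 4)))

# check_parameterNamesMatch raises when two state dicts disagree on keys.
sd_a = {"w": torch.zeros(1)}
sd_b = {"w": torch.zeros(1), "extra": torch.zeros(1)}
check_parameterNamesMatch([sd_a, sd_a])  # passes: identical key sets
try:
    check_parameterNamesMatch([sd_a, sd_b])
except ValueError as err:
    print(err)  # reports the symmetric difference: {'extra'}
```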
fusion_bench/models/we_moe.py
CHANGED
```diff
@@ -8,64 +8,12 @@ from torch import Tensor, nn
 from torch.func import functional_call
 from torch.nn import functional as F
 
+from fusion_bench.models.utils import del_attr, get_attr, set_attr
 from fusion_bench.utils.type import StateDictType
 
 log = logging.getLogger(__name__)
 
 
-def join_list(list_of_list: List[List]):
-    ans = []
-    for l in list_of_list:
-        ans.extend(l)
-    return ans
-
-
-def del_attr(obj, names: List[str]):
-    """
-    Deletes an attribute from an object recursively.
-
-    Args:
-        obj (object): Object to delete attribute from.
-        names (list): List of attribute names to delete recursively.
-    """
-    if len(names) == 1:
-        delattr(obj, names[0])
-    else:
-        del_attr(getattr(obj, names[0]), names[1:])
-
-
-def set_attr(obj, names: List[str], val):
-    """
-    Sets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to set attribute of.
-        names (list): List of attribute names to set recursively.
-        val (object): Value to set the attribute to.
-    """
-    if len(names) == 1:
-        setattr(obj, names[0], val)
-    else:
-        set_attr(getattr(obj, names[0]), names[1:], val)
-
-
-def get_attr(obj, names: List[str]):
-    """
-    Gets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to get attribute of.
-        names (list): List of attribute names to get recursively.
-
-    Returns:
-        object: The attribute of the object.
-    """
-    if len(names) == 1:
-        return getattr(obj, names[0])
-    else:
-        return get_attr(getattr(obj, names[0]), names[1:])
-
-
 class Depth_0_Gate(nn.Module):
     def __init__(self, num_experts: int):
         super().__init__()
```
fusion_bench/models/wrappers/ensemble.py
CHANGED
```diff
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, cast
+from typing import Any, Callable, Dict, List, Union, cast
 
 import numpy as np
 import torch
@@ -6,7 +6,9 @@ from omegaconf import ListConfig
 from torch import Tensor, nn
 
 
-def aggregate_tensors(outputs: List[Any], aggregate_fn: Callable):
+def aggregate_tensors(
+    outputs: List[Any], aggregate_fn: Callable
+) -> Union[Tensor, Dict, List, None]:
     """
     Aggregates a list of outputs using the provided aggregation function.
 
@@ -84,7 +86,7 @@ class EnsembleModule(nn.Module):
         """
         return torch.stack(outputs).mean(dim=0)
 
-    def forward(self, *args, **kwargs):
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Performs a forward pass by averaging the outputs of the models.
 
@@ -150,7 +152,7 @@ class WeightedEnsembleModule(nn.Module):
         weights = cast(Tensor, self.weights).view(-1, *([1] * outputs[0].dim()))
         return (torch.stack(outputs) * weights).sum(dim=0)
 
-    def forward(self, *args, **kwargs):
+    def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Performs a forward pass by computing the weighted average of the models' outputs.
 
```
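`aggregate_tensors` applies the given `aggregate_fn` over the models' outputs, and `EnsembleModule`'s own aggregator is the `torch.stack(outputs).mean(dim=0)` shown above. A hedged sketch of the flat-tensor case, assuming the base case simply hands the list of tensors to `aggregate_fn` (which the signature and docstring suggest):

```python
import torch

from fusion_bench.models.wrappers.ensemble import aggregate_tensors

outs = [torch.ones(2, 3), torch.zeros(2, 3)]
# Same reduction EnsembleModule uses internally for plain tensor outputs.
mean = aggregate_tensors(outs, lambda ts: torch.stack(ts).mean(dim=0))
print(mean)  # 0.5 everywhere
```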
fusion_bench/models/wrappers/layer_wise_fusion.py
CHANGED
```diff
@@ -49,7 +49,7 @@ def get_layer_wise_weights(
     return torch.full((num_models, num_layers), init_values, dtype=dtype)
 
 
-def _fuse_weights(layer_wise_weight: Tensor, tensors: List[Tensor]):
+def _fuse_weights(layer_wise_weight: Tensor, tensors: List[Tensor]) -> Tensor:
     """
     Fuse the layer-wise weights with the given state dictionaries.
 
```
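`get_layer_wise_weights` fills a `(num_models, num_layers)` grid with `init_values`, and `_fuse_weights` reduces one layer's per-model tensors to a single tensor under those weights. A minimal standalone sketch of the arithmetic (`fuse_one_layer` is a hypothetical name for illustration, not the library's code):

```python
from typing import List

import torch
from torch import Tensor


def fuse_one_layer(layer_weight: Tensor, tensors: List[Tensor]) -> Tensor:
    """Weighted sum of one layer's tensors; layer_weight has shape (num_models,)."""
    assert layer_weight.numel() == len(tensors)
    return sum(w * t for w, t in zip(layer_weight, tensors))


# Two models, one layer, uniform 0.5 weights (cf. get_layer_wise_weights(2, 1, 0.5)).
weights = torch.full((2, 1), 0.5)
a, b = torch.ones(3), torch.zeros(3)
print(fuse_one_layer(weights[:, 0], [a, b]))  # tensor([0.5000, 0.5000, 0.5000])
```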