fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/ada_svd/clip_vision.py +4 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +16 -7
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
- fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
- fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
- fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/rankone_moe.py +2 -88
- fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
- fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
- fusion_bench/models/smile_moe/utils/__init__.py +24 -0
- fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +7 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
fusion_bench/mixins/hydra_config.py (new file)
@@ -0,0 +1,49 @@
+import logging
+import os
+from copy import deepcopy
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import hydra.core.global_hydra
+from hydra import compose, initialize
+from omegaconf import DictConfig, OmegaConf
+
+from fusion_bench.utils import import_object, instantiate
+from fusion_bench.utils.instantiate import set_print_function_call
+
+log = logging.getLogger(__name__)
+
+
+class HydraConfigMixin:
+    """
+    A mixin for classes that need to be instantiated from a config file.
+    """
+
+    @classmethod
+    def from_config(
+        cls,
+        config_name: Union[str, Path],
+        overrides: Optional[List[str]] = None,
+    ):
+        if not hydra.core.global_hydra.GlobalHydra.instance().is_initialized():
+            raise RuntimeError("Hydra is not initialized.")
+        else:
+            cfg = compose(config_name=config_name, overrides=overrides)
+
+        config_groups = config_name.split("/")[:-1]
+        for config_group in config_groups:
+            cfg = cfg[config_group]
+
+        if "_target_" in cfg:
+            # if the config has a _target_ key, check if it is equal to the class name
+            target_cls = import_object(cfg["_target_"])
+            if target_cls != cls:
+                log.warning(
+                    f"The _target_ key in the config is {cfg['_target_']}, but the class name is {cls.__name__}."
+                )
+            with set_print_function_call(False):
+                obj = instantiate(cfg)
+        else:
+            obj = cls(**cfg)
+
+        return obj
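The mixin composes a config by name, indexes into the config groups encoded in the path, and then either instantiates the `_target_` or calls the class constructor with the remaining keys. A minimal usage sketch, assuming a local `conf/` directory and a hypothetical `method/my_method.yaml` packaged under its group as the mixin expects; the class name and fields are placeholders:

```python
from hydra import initialize

from fusion_bench.mixins.hydra_config import HydraConfigMixin


class MyMethod(HydraConfigMixin):
    """Hypothetical class whose defaults live under a `method/` config group."""

    def __init__(self, scaling_factor: float = 0.3):
        self.scaling_factor = scaling_factor


# `from_config` raises a RuntimeError unless Hydra is already initialized,
# so enter an initialization context first (config_path is a placeholder).
with initialize(config_path="conf", version_base=None):
    # Composes conf/method/my_method.yaml, indexes into the "method" group,
    # then instantiates `_target_` if present, otherwise calls MyMethod(**cfg).
    method = MyMethod.from_config("method/my_method")
    print(method.scaling_factor)
```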
fusion_bench/mixins/openclip_classification.py (new file)
@@ -0,0 +1,11 @@
+import logging
+
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.models.open_clip import ImageClassifier, ImageEncoder
+
+log = logging.getLogger(__name__)
+
+
+class OpenCLIPClassificationMixin(LightningFabricMixin):
+    _train_processor = None
+    _test_processor = None
fusion_bench/mixins/simple_profiler.py
@@ -1,5 +1,5 @@
 from contextlib import contextmanager
-from typing import Generator
+from typing import Generator, Optional
 
 from lightning.fabric.utilities.rank_zero import rank_zero_only
 from lightning.pytorch.profilers import SimpleProfiler
@@ -70,7 +70,9 @@ class SimpleProfilerMixin:
         self.profiler.stop(action_name)
 
     @rank_zero_only
-    def print_profile_summary(self):
+    def print_profile_summary(self, title: Optional[str] = None):
+        if title is not None:
+            print(title)
         print(self.profiler.summary())
 
     def __del__(self):
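The only behavioral change to `SimpleProfilerMixin` is the optional `title` argument, printed above the profiler summary. A small sketch, assuming the mixin's usual `profile` context manager for timing actions; the class and action names are placeholders:

```python
from fusion_bench.mixins.simple_profiler import SimpleProfilerMixin


class DemoTask(SimpleProfilerMixin):
    """Hypothetical task that only exercises the profiler mixin."""

    def run(self):
        # time a block of work under an action name
        with self.profile("load model"):
            sum(range(1_000_000))  # stand-in for real work
        # new in this release: an optional heading printed before the summary
        self.print_profile_summary(title="DemoTask profiling summary")


DemoTask().run()
```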
fusion_bench/modelpool/__init__.py
@@ -6,12 +6,13 @@ from fusion_bench.utils.lazy_imports import LazyImporter
 
 _import_structure = {
     "base_pool": ["BaseModelPool"],
+    "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
     "clip_vision": ["CLIPVisionModelPool"],
     "nyuv2_modelpool": ["NYUv2ModelPool"],
     "huggingface_automodel": ["AutoModelPool"],
-    "causal_lm": ["CausalLMPool", "CausalLMBackbonePool"],
     "seq2seq_lm": ["Seq2SeqLMPool"],
     "PeftModelForSeq2SeqLM": ["PeftModelForSeq2SeqLMPool"],
+    "openclip_vision": ["OpenCLIPVisionModelPool"],
     "huggingface_gpt2_classification": [
         "HuggingFaceGPT2ClassificationPool",
         "GPT2ForSequenceClassificationPool",
@@ -30,6 +31,7 @@ if TYPE_CHECKING:
         HuggingFaceGPT2ClassificationPool,
     )
     from .nyuv2_modelpool import NYUv2ModelPool
+    from .openclip_vision import OpenCLIPVisionModelPool
     from .PeftModelForSeq2SeqLM import PeftModelForSeq2SeqLMPool
     from .seq2seq_lm import Seq2SeqLMPool
     from .seq_classification_lm import SeqenceClassificationModelPool
fusion_bench/modelpool/base_pool.py
@@ -7,7 +7,7 @@ from omegaconf import DictConfig
 from torch import nn
 from torch.utils.data import Dataset
 
-from fusion_bench.mixins import BaseYAMLSerializableModel
+from fusion_bench.mixins import BaseYAMLSerializableModel, HydraConfigMixin
 from fusion_bench.utils import instantiate, timeit_context
 
 __all__ = ["BaseModelPool"]
@@ -15,7 +15,7 @@ __all__ = ["BaseModelPool"]
 log = logging.getLogger(__name__)
 
 
-class BaseModelPool(BaseYAMLSerializableModel):
+class BaseModelPool(BaseYAMLSerializableModel, HydraConfigMixin):
     """
     A class for managing and interacting with a pool of models along with their associated datasets or other specifications. For example, a model pool may contain multiple models, each with its own training, validation, and testing datasets. As for the specifications, a vision model pool may contain image preprocessor, and a language model pool may contain a tokenizer.
 
fusion_bench/modelpool/openclip_vision/__init__.py (new file)
@@ -0,0 +1 @@
+from .modelpool import OpenCLIPVisionModelPool
fusion_bench/modelpool/openclip_vision/modelpool.py (new file)
@@ -0,0 +1,255 @@
+import logging
+import pickle
+import sys
+from typing import Callable, Optional, Union, cast
+
+import torch
+from datasets import load_dataset
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+
+from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.models.open_clip import ClassificationHead, ImageEncoder
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.expr import is_expr_match
+from fusion_bench.utils.packages import _get_package_version, compare_versions
+
+log = logging.getLogger(__name__)
+
+# Add flag to track if warning has been shown
+_openclip_version_warning_shown = False
+
+
+def _check_and_redirect_open_clip_modeling():
+    global _openclip_version_warning_shown
+    if compare_versions(_get_package_version("open-clip-torch").__str__(), "2.0.2") > 0:
+        if not _openclip_version_warning_shown:
+            log.warning(
+                "OpenCLIP version is greater than 2.0.2. This may cause issues with the modelpool."
+            )
+            _openclip_version_warning_shown = True
+        import open_clip.model
+        import open_clip.transformer
+
+        if not hasattr(open_clip.model, "VisualTransformer"):
+            open_clip.model.VisualTransformer = open_clip.model.VisionTransformer
+        if not hasattr(open_clip.model, "Transformer"):
+            open_clip.model.Transformer = open_clip.transformer.Transformer
+        if not hasattr(open_clip.model, "ResidualAttentionBlock"):
+            open_clip.model.ResidualAttentionBlock = (
+                open_clip.transformer.ResidualAttentionBlock
+            )
+
+    try:
+        import src
+        import src.modeling
+    except ImportError:
+        if "src" not in sys.modules:
+            # redirect the import of `src` to `fusion_bench.models.open_clip`
+            import fusion_bench.models.open_clip as open_clip
+
+            sys.modules["src"] = open_clip
+            log.warning(
+                "`src` is not imported."
+                "Redirecting the import to `fusion_bench.models.open_clip`"
+            )
+        if "src.modeling" not in sys.modules:
+            # redirect the import of `src.modeling` to `fusion_bench.models.open_clip.modeling`
+            import fusion_bench.models.open_clip.modeling as open_clip_modeling
+
+            sys.modules["src.modeling"] = open_clip_modeling
+            log.warning(
+                "`src.modeling` is not imported."
+                "Redirecting the import to `fusion_bench.models.open_clip.modeling`"
+            )
+
+
+def load_classifier_head(model_config: Union[str, DictConfig], *args, **kwargs):
+    if isinstance(model_config, str):
+        _check_and_redirect_open_clip_modeling()
+        log.info(f"Loading `ClassificationHead` from {model_config}")
+        weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+        head = torch.load(model_config, weights_only=weights_only, *args, **kwargs)
+    elif isinstance(model_config, nn.Module):
+        log.info(f"Returning existing model: {model_config}")
+        head = model_config
+    else:
+        head = instantiate(model_config, *args, **kwargs)
+    head = cast(ClassificationHead, head)
+    return head
+
+
+class OpenCLIPVisionModelPool(BaseModelPool):
+    """
+    A model pool for managing OpenCLIP Vision models (models from task vector paper).
+    """
+
+    _train_processor = None
+    _test_processor = None
+
+    def __init__(
+        self,
+        models: DictConfig,
+        classification_heads: Optional[DictConfig] = None,
+        **kwargs,
+    ):
+        super().__init__(models, **kwargs)
+        self._classification_heads = classification_heads
+
+    @property
+    def train_processor(self):
+        if self._train_processor is None:
+            encoder: ImageEncoder = self.load_pretrained_or_first_model()
+            self._train_processor = encoder.train_preprocess
+            if self._test_processor is None:
+                self._test_processor = encoder.val_preprocess
+        return self._train_processor
+
+    @property
+    def test_processor(self):
+        if self._test_processor is None:
+            encoder: ImageEncoder = self.load_pretrained_or_first_model()
+            if self._train_processor is None:
+                self._train_processor = encoder.train_preprocess
+            self._test_processor = encoder.val_preprocess
+        return self._test_processor
+
+    def load_model(
+        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
+    ) -> ImageEncoder:
+        R"""
+        The model config can be:
+
+        - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
+        - {"model_name": str, "pickle_path": str}, load the model from the binary file (pickle format). This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from model located in the pickle file.
+        - {"model_name": str, "state_dict_path": str}, load the model from the state dict file. This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from the file.
+        - Default, load the model using `instantiate` from hydra.
+        """
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._models
+        ):
+            model_config = self._models[model_name_or_config]
+        else:
+            model_config = model_name_or_config
+        if isinstance(model_config, DictConfig):
+            model_config = OmegaConf.to_container(model_config, resolve=True)
+
+        if isinstance(model_config, str):
+            # the model config is a string, which is the path to the model checkpoint in pickle format
+            # load the model using `torch.load`
+            # this is the original usage in the task arithmetic codebase
+            _check_and_redirect_open_clip_modeling()
+            log.info(f"loading ImageEncoder from {model_config}")
+            weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+            try:
+                encoder = torch.load(
+                    model_config, weights_only=weights_only, *args, **kwargs
+                )
+            except RuntimeError as e:
+                encoder = pickle.load(open(model_config, "rb"))
+        elif is_expr_match({"model_name": str, "pickle_path": str}, model_config):
+            # the model config is a dictionary with the following keys:
+            # - model_name: str, the name of the model
+            # - pickle_path: str, the path to the binary file (pickle format)
+            # load the model from the binary file (pickle format)
+            # this is useful when you use a newer version of torchvision
+            _check_and_redirect_open_clip_modeling()
+            log.info(
+                f"loading ImageEncoder of {model_config['model_name']} from {model_config['pickle_path']}"
+            )
+            weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+            try:
+                encoder = torch.load(
+                    model_config["pickle_path"],
+                    weights_only=weights_only,
+                    *args,
+                    **kwargs,
+                )
+            except RuntimeError as e:
+                encoder = pickle.load(open(model_config["pickle_path"], "rb"))
+            _encoder = ImageEncoder(model_config["model_name"])
+            _encoder.load_state_dict(encoder.state_dict())
+            encoder = _encoder
+        elif is_expr_match({"model_name": str, "state_dict_path": str}, model_config):
+            # the model config is a dictionary with the following keys:
+            # - model_name: str, the name of the model
+            # - state_dict_path: str, the path to the state dict file
+            # load the model from the state dict file
+            log.info(
+                f"loading ImageEncoder of {model_config['model_name']} from {model_config['state_dict_path']}"
+            )
+            encoder = ImageEncoder(model_config["model_name"])
+            encoder.load_state_dict(
+                torch.load(
+                    model_config["state_dict_path"], weights_only=True, *args, **kwargs
+                )
+            )
+        elif isinstance(model_config, nn.Module):
+            # the model config is an existing model
+            log.info(f"Returning existing model: {model_config}")
+            encoder = model_config
+        else:
+            encoder = super().load_model(model_name_or_config, *args, **kwargs)
+        encoder = cast(ImageEncoder, encoder)
+
+        # setup the train and test processors
+        if self._train_processor is None and hasattr(encoder, "train_preprocess"):
+            self._train_processor = encoder.train_preprocess
+        if self._test_processor is None and hasattr(encoder, "val_preprocess"):
+            self._test_processor = encoder.val_preprocess
+
+        return encoder
+
+    def load_classification_head(
+        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
+    ) -> ClassificationHead:
+        R"""
+        The model config can be:
+
+        - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
+        - Default, load the model using `instantiate` from hydra.
+        """
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._classification_heads
+        ):
+            model_config = self._classification_heads[model_name_or_config]
+        else:
+            model_config = model_name_or_config
+
+        head = load_classifier_head(model_config, *args, **kwargs)
+        return head
+
+    def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._train_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="train")
+        else:
+            dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
+        return dataset
+
+    def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._val_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="validation")
+        else:
+            dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
+        return dataset
+
+    def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._test_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="test")
+        else:
+            dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
+        return dataset
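The new pool takes its `models` (and optional `classification_heads`) as a `DictConfig`, with entry formats as described in the `load_model` docstring above. A construction sketch; the checkpoint paths are placeholders, and the `_pretrained_` key follows fusion_bench's usual naming convention rather than anything specific to this file:

```python
from omegaconf import OmegaConf

from fusion_bench.modelpool import OpenCLIPVisionModelPool

config = OmegaConf.create(
    {
        "models": {
            # string entry: a pickled ImageEncoder loaded via torch.load
            "_pretrained_": "checkpoints/ViT-B-32/zeroshot.pt",
            # dict entry: build ImageEncoder("ViT-B-32"), then load the state dict
            "sun397": {
                "model_name": "ViT-B-32",
                "state_dict_path": "checkpoints/ViT-B-32/sun397/finetuned.pt",
            },
        },
        "classification_heads": {
            "sun397": "checkpoints/ViT-B-32/head_sun397.pt",
        },
    }
)

pool = OpenCLIPVisionModelPool(
    models=config.models,
    classification_heads=config.classification_heads,
)
encoder = pool.load_model("sun397")             # -> ImageEncoder
head = pool.load_classification_head("sun397")  # -> ClassificationHead
```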
fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py
@@ -27,6 +27,8 @@ from transformers.models.mistral.modeling_mistral import (
     MistralRotaryEmbedding,
 )
 
+from fusion_bench.models.smile_moe.linear_from_hf_config import SmileLinear
+
 from .configuration_smile_mistral import SmileMistralConfig
 
 logger = logging.getLogger(__name__)
@@ -80,209 +82,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
-class SmileGate(nn.Module):
-    __constants__ = ["in_features", "num_experts", "k"]
-    in_features: int
-    num_experts: int
-    k: int
-    weight: Tensor
-
-    def __init__(
-        self,
-        in_features: int,
-        num_experts: int,
-        k: int,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.input_features = in_features
-        self.num_experts = num_experts
-        self.k = k
-
-        self.weight = nn.Parameter(
-            torch.empty(num_experts * k, in_features, **factory_kwargs)
-        )
-
-    def forward(self, x: Tensor):
-        batch_size = x.size(0)
-        if self.num_experts == 1:
-            return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype)
-
-        routing_weights = F.linear(x, self.weight).view(
-            batch_size, self.num_experts, self.k
-        )
-        routing_weights = routing_weights.norm(p=2, dim=2)
-        return routing_weights
-
-
-class SmileLinearExpert(nn.Module):
-    __constants__ = ["in_features", "out_features", "k"]
-    in_features: int
-    out_features: int
-    k: int
-
-    def __init__(
-        self,
-        in_features,
-        out_features,
-        k: int,
-        bias: bool,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.in_features = in_features
-        self.out_features = out_features
-        self.k = k
-
-        self.u = nn.Parameter(torch.empty(out_features, k, **factory_kwargs))
-        self.svh = nn.Parameter(torch.empty(k, in_features, **factory_kwargs))
-
-        if bias:
-            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
-        else:
-            self.register_parameter("bias", None)
-
-    def forward(self, x):
-        x = F.linear(x, self.svh)
-        x = F.linear(x, self.u, self.bias)
-        return x
-
-
-class SmileLinear(nn.Module):
-    @torch.no_grad()
-    def __init__(
-        self,
-        config: SmileMistralConfig,
-        in_features,
-        out_features,
-        bias: bool,
-        device=None,
-        dtype=None,
-    ):
-        factory_kwargs = {"device": device, "dtype": dtype}
-        super().__init__()
-        self.num_local_experts = config.num_local_experts
-        self.num_experts_per_tok = config.num_experts_per_tok
-        self.rank_of_expert = config.rank_of_expert
-        self.rank_of_router = config.rank_of_router
-        self.in_features = in_features
-        self.out_features = out_features
-
-        # construct the gate network
-        self.gate = SmileGate(
-            in_features=in_features,
-            num_experts=self.num_local_experts,
-            k=self.rank_of_router,
-            **factory_kwargs,
-        )
-
-        # the shared linear
-        self.shared_linear = nn.Linear(
-            in_features, out_features, bias=bias, **factory_kwargs
-        )
-
-        # construct experts
-        if self.rank_of_expert > 0:
-            self.experts = nn.ModuleList(
-                [
-                    SmileLinearExpert(
-                        in_features=in_features,
-                        out_features=out_features,
-                        bias=bias,
-                        k=self.rank_of_expert,
-                        **factory_kwargs,
-                    )
-                    for _ in range(self.num_local_experts)
-                ]
-            )
-        else:
-            self.experts = nn.ModuleList(
-                [
-                    nn.Linear(in_features, out_features, bias=bias, **factory_kwargs)
-                    for _ in range(self.num_local_experts)
-                ]
-            )
-
-    def forward(self, hidden_states: Tensor):
-        pretrained_out = self.shared_linear(hidden_states)
-
-        input_shape = hidden_states.size()
-        hidden_states = hidden_states.view(-1, self.in_features)
-
-        router_logits = self.gate(hidden_states)
-        routing_weights = F.softmax(router_logits, dim=1)
-        # sample the expert according to the routing weights
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.num_experts_per_tok, dim=-1
-        )
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-
-        final_hidden_states = torch.zeros(
-            (hidden_states.size(0), self.out_features),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
-
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(
-            selected_experts, num_classes=self.num_local_experts
-        ).permute(2, 1, 0)
-
-        # Loop over all available experts in the model and perform the computation on each expert
-        for expert_idx in range(self.num_local_experts):
-            expert_layer = self.experts[expert_idx]
-            idx, top_x = torch.where(expert_mask[expert_idx])
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, self.in_features)
-            if current_state.numel() == 0:
-                continue
-            current_hidden_states = (
-                expert_layer(current_state) * routing_weights[top_x, idx, None]
-            )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
-        final_hidden_states = final_hidden_states.reshape(
-            *input_shape[:-1], self.out_features
-        )
-        final_hidden_states = pretrained_out + final_hidden_states
-        return final_hidden_states
-
-    @property
-    def weight(self):
-        """
-        Mimic linear layer. Bacause in some cases, user might indicate the device (or dtype of parameters) of the linear layer using `linear_layer.weight.device`
-        """
-        return self.shared_linear.weight
-
-    @property
-    def bias(self):
-        return self.shared_linear.bias
-
-    def __repr__(self):
-        return (
-            f"SingularMoELinear("
-            f"in_features={self.shared_linear.in_features}, "
-            f"out_features={self.shared_linear.out_features}, "
-            f"num_local_experts={self.num_local_experts}, "
-            f"num_experts_per_tok={self.num_experts_per_tok}, "
-            f"rank_of_router={self.rank_of_router}, "
-            f"rank_of_expert={self.rank_of_expert}"
-            f")"
-        )
-
-
 class SmileMistralAttention(nn.Module):
     """
     Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py (new file)
@@ -0,0 +1,21 @@
+from transformers import PretrainedConfig
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+
+
+class SmileQwen2Config(Qwen2Config):
+    model_type = "smile_qwen2"
+
+    def __init__(
+        self,
+        num_experts_per_tok: int = 1,
+        rank_of_router: int = None,
+        rank_of_expert: int = None,
+        num_local_experts: int = None,
+        **kwargs,
+    ):
+        self.num_experts_per_tok = num_experts_per_tok
+        self.rank_of_router = rank_of_router
+        self.rank_of_expert = rank_of_expert
+        self.num_local_experts = num_local_experts
+
+        super().__init__(**kwargs)