fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/ada_svd/clip_vision.py +4 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +16 -7
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
- fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
- fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
- fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/rankone_moe.py +2 -88
- fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
- fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
- fusion_bench/models/smile_moe/utils/__init__.py +24 -0
- fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +7 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
@@ -8,64 +8,13 @@ from torch import Tensor, nn
 from torch.func import functional_call
 from torch.nn import functional as F
 
+from fusion_bench.models.smile_moe.utils import _is_all_zeros, svd
+from fusion_bench.models.utils import del_attr, get_attr, set_attr
 from fusion_bench.utils.type import StateDictType
 
 log = logging.getLogger(__name__)
 
 
-def join_list(list_of_list: List[List]):
-    ans = []
-    for l in list_of_list:
-        ans.extend(l)
-    return ans
-
-
-def del_attr(obj, names: List[str]):
-    """
-    Deletes an attribute from an object recursively.
-
-    Args:
-        obj (object): Object to delete attribute from.
-        names (list): List of attribute names to delete recursively.
-    """
-    if len(names) == 1:
-        delattr(obj, names[0])
-    else:
-        del_attr(getattr(obj, names[0]), names[1:])
-
-
-def set_attr(obj, names: List[str], val):
-    """
-    Sets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to set attribute of.
-        names (list): List of attribute names to set recursively.
-        val (object): Value to set the attribute to.
-    """
-    if len(names) == 1:
-        setattr(obj, names[0], val)
-    else:
-        set_attr(getattr(obj, names[0]), names[1:], val)
-
-
-def get_attr(obj, names: List[str]):
-    """
-    Gets an attribute of an object recursively.
-
-    Args:
-        obj (object): Object to get attribute of.
-        names (list): List of attribute names to get recursively.
-
-    Returns:
-        object: The attribute of the object.
-    """
-    if len(names) == 1:
-        return getattr(obj, names[0])
-    else:
-        return get_attr(getattr(obj, names[0]), names[1:])
-
-
 class Depth_0_Gate(nn.Module):
     def __init__(self, num_experts: int):
         super().__init__()
@@ -132,41 +81,6 @@ class ExpertNotTrainedError(Exception):
     pass
 
 
-def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
-    """
-    Check if a tensor or a list of tensors are all zeros.
-    """
-    if isinstance(tensor, Tensor):
-        return torch.allclose(tensor, torch.zeros_like(tensor))
-    else:
-        return all(_is_all_zeros(t) for t in tensor)
-
-
-def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Perform Singular Value Decomposition (SVD) on a tensor.
-    """
-    u, s, vh = torch.linalg.svd(
-        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
-    )
-    v = vh.T
-    return u, s, v
-
-
-def svd(
-    w: Tensor, full_matrices=True, accelerator=None
-) -> Tuple[Tensor, Tensor, Tensor]:
-    """
-    Perform SVD on a tensor, optionally using a specified accelerator.
-    """
-    if accelerator is None:
-        return _svd(w, full_matrices=full_matrices)
-    original_device = w.device
-    w = w.to(accelerator)
-    u, s, v = _svd(w)
-    return u.to(original_device), s.to(original_device), v.to(original_device)
-
-
 def fun_joint_svd(
     w_list: List[Tensor], accelerator=None
 ) -> Tuple[Tensor, Tensor, Tensor]:
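The removed helpers above are replaced by the shared imports added at the top of this hunk. A minimal sketch of the relocated helpers in use, assuming `fusion_bench.models.smile_moe.utils` keeps the same signatures as the local copies deleted here:

```python
# Sketch only: assumes fusion_bench.models.smile_moe.utils exposes `svd` and
# `_is_all_zeros` with the same behavior as the functions removed above.
import torch
from fusion_bench.models.smile_moe.utils import _is_all_zeros, svd

w = torch.randn(64, 32)
u, s, v = svd(w, full_matrices=True)  # u: (64, 64), s: (32,), v: (32, 32)
# the thin factors reconstruct w up to numerical error
print(torch.allclose(w, u[:, :32] @ torch.diag(s) @ v.T, atol=1e-5))  # True

print(_is_all_zeros(torch.zeros(4, 4)))           # True
print(_is_all_zeros([w, torch.zeros_like(w)]))    # False (checks every tensor)
```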
@@ -0,0 +1,373 @@
+from typing import List, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+from fusion_bench.utils.state_dict_arithmetic import state_dict_sub
+
+from .utils import _is_all_zeros
+
+
+class ExpertNotTrainedError(Exception):
+    pass
+
+
+def _svd(w: Tensor, full_matrices=False) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform Singular Value Decomposition (SVD) on a tensor.
+
+    Args:
+        w (Tensor): The input tensor.
+        full_matrices (bool, optional): Whether to compute the full-sized U and V matrices. Defaults to False.
+
+    Returns:
+        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+    """
+    dtype = w.dtype
+    if w.dtype != torch.float32 or w.dtype != torch.float64:
+        w = w.float()
+
+    u, s, vh = torch.linalg.svd(
+        w,
+        full_matrices=full_matrices,
+        # driver="gesvd" if w.is_cuda else None
+    )
+    v = vh.T
+
+    u = u.to(dtype=dtype)
+    s = s.to(dtype=dtype)
+    v = v.to(dtype=dtype)
+    return u, s, v
+
+
+def svd(
+    w: Tensor, full_matrices=True, accelerator=None
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Perform SVD on a tensor with optional acceleration.
+    This is different from `.utils.svd` in that it handles tensors with precision other than float32 or float64.
+
+    Args:
+        w (Tensor): The input tensor.
+        full_matrices (bool, optional): Whether to compute the full-sized U and V matrices. Defaults to True.
+        accelerator (optional): The device to perform the computation on. Defaults to None.
+
+    Returns:
+        Tuple[Tensor, Tensor, Tensor]: The U, S, and V matrices from SVD.
+    """
+    if accelerator is None:
+        return _svd(w, full_matrices=full_matrices)
+    original_device = w.device
+    w = w.to(accelerator)
+    u, s, v = _svd(w)
+    return u, s, v
+
+
+class SmileMoEConfig:
+    """
+    Example PretrainedConfig for SmileMoE.
+
+    Args:
+        num_experts_per_tok: Number of experts per token.
+        rank_of_router: Rank of the router.
+        rank_of_expert: Rank of the expert.
+        num_local_experts: Number of local experts.
+    """
+
+    num_experts_per_tok: int
+    rank_of_router: int
+    rank_of_expert: int
+    num_local_experts: int
+
+
+class SmileGate(nn.Module):
+    __constants__ = ["in_features", "num_experts", "k"]
+    in_features: int
+    num_experts: int
+    k: int
+    weight: nn.Parameter
+
+    def __init__(
+        self,
+        in_features: int,
+        num_experts: int,
+        k: int,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.input_features = in_features
+        self.num_experts = num_experts
+        self.k = k
+
+        self.weight = nn.Parameter(
+            torch.empty(num_experts * k, in_features, **factory_kwargs)
+        )
+
+    def forward(self, x: Tensor):
+        batch_size = x.size(0)
+        if self.num_experts == 1:
+            return torch.ones(batch_size, 1, device=x.device, dtype=x.dtype)
+
+        routing_weights = F.linear(x, self.weight).view(
+            batch_size, self.num_experts, self.k
+        )
+        routing_weights = routing_weights.norm(p=2, dim=2)
+        return routing_weights
+
+
+class SmileLinearExpert(nn.Module):
+    __constants__ = ["in_features", "out_features", "k"]
+    in_features: int
+    out_features: int
+    k: int
+
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        k: int,
+        bias: bool,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.k = k
+        if k > 0:
+            # check k < in_features and out_features
+            if k > in_features:
+                raise ValueError(
+                    f"k ({k}) must not be greater than in_features ({in_features})"
+                )
+            if k > out_features:
+                raise ValueError(
+                    f"k ({k}) must not be greater than out_features ({out_features})"
+                )
+
+        self.u = nn.Parameter(torch.empty(out_features, k, **factory_kwargs))
+        self.svh = nn.Parameter(torch.empty(k, in_features, **factory_kwargs))
+
+        if bias:
+            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
+        else:
+            self.register_parameter("bias", None)
+
+    def forward(self, x):
+        x = F.linear(x, self.svh)
+        x = F.linear(x, self.u, self.bias)
+        return x
+
+
+class SmileLinear(nn.Module):
+    __constants__ = [
+        "in_features",
+        "out_features",
+        "num_local_experts",
+        "num_experts_per_tok",
+        "rank_of_expert",
+        "rank_of_router",
+    ]
+
+    in_features: int
+    out_features: int
+    num_local_experts: int
+    num_experts_per_tok: int
+    rank_of_expert: int
+    rank_of_router: int
+
+    @torch.no_grad()
+    def __init__(
+        self,
+        config: SmileMoEConfig,
+        in_features,
+        out_features,
+        bias: bool,
+        device=None,
+        dtype=None,
+    ):
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.num_local_experts = config.num_local_experts
+        self.num_experts_per_tok = config.num_experts_per_tok
+        self.rank_of_expert = config.rank_of_expert
+        self.rank_of_router = config.rank_of_router
+        self.in_features = in_features
+        self.out_features = out_features
+
+        # construct the gate network
+        self.gate = SmileGate(
+            in_features=in_features,
+            num_experts=self.num_local_experts,
+            k=self.rank_of_router,
+            **factory_kwargs,
+        )
+
+        # the shared linear
+        self.shared_linear = nn.Linear(
+            in_features, out_features, bias=bias, **factory_kwargs
+        )
+
+        # construct experts
+        if self.rank_of_expert > 0:
+            self.experts = nn.ModuleList(
+                [
+                    SmileLinearExpert(
+                        in_features=in_features,
+                        out_features=out_features,
+                        bias=bias,
+                        k=self.rank_of_expert,
+                        **factory_kwargs,
+                    )
+                    for _ in range(self.num_local_experts)
+                ]
+            )
+        else:
+            self.experts = nn.ModuleList(
+                [
+                    nn.Linear(in_features, out_features, bias=bias, **factory_kwargs)
+                    for _ in range(self.num_local_experts)
+                ]
+            )
+
+    def forward(self, hidden_states: Tensor):
+        pretrained_out = self.shared_linear(hidden_states)
+
+        input_shape = hidden_states.size()
+        hidden_states = hidden_states.view(-1, self.in_features)
+
+        router_logits = self.gate(hidden_states)
+        routing_weights = F.softmax(router_logits, dim=1)
+        # sample the expert according to the routing weights
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.num_experts_per_tok, dim=-1
+        )
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        final_hidden_states = torch.zeros(
+            (hidden_states.size(0), self.out_features),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be sollicitated
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_local_experts
+        ).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_local_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x].reshape(-1, self.in_features)
+            if current_state.numel() == 0:
+                continue
+            current_hidden_states = (
+                expert_layer(current_state) * routing_weights[top_x, idx, None]
+            )
+
+            # However `index_add_` only support torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+        final_hidden_states = final_hidden_states.reshape(
+            *input_shape[:-1], self.out_features
+        )
+        final_hidden_states = pretrained_out + final_hidden_states
+        return final_hidden_states
+
+    @property
+    def weight(self):
+        """
+        Mimic linear layer. Bacause in some cases, user might indicate the device (or dtype of parameters) of the linear layer using `linear_layer.weight.device`
+        """
+        return self.shared_linear.weight
+
+    @property
+    def bias(self):
+        return self.shared_linear.bias
+
+    def __repr__(self):
+        return (
+            f"SingularMoELinear("
+            f"in_features={self.shared_linear.in_features}, "
+            f"out_features={self.shared_linear.out_features}, "
+            f"num_local_experts={self.num_local_experts}, "
+            f"num_experts_per_tok={self.num_experts_per_tok}, "
+            f"rank_of_router={self.rank_of_router}, "
+            f"rank_of_expert={self.rank_of_expert}"
+            f")"
+        )
+
+
+@torch.no_grad()
+def upscale_to_smile_linear(
+    base: nn.Linear, experts: List[nn.Linear], target: SmileLinear, accelerator=None
+):
+    """
+    Upscale a base linear layer to a SmileLinear layer using expert models.
+
+    Args:
+        base (nn.Linear): The base linear layer.
+        experts (List[nn.Linear]): A list of expert linear layers.
+        target (SmileLinear): The target SmileLinear layer.
+        accelerator (optional): The device to perform the computation on. Defaults to None.
+
+    Returns:
+        SmileLinear: The upscaled SmileLinear layer.
+    """
+    w = base.weight
+    w_ft_list = [e.weight for e in experts]
+    dw_list = [w_ft - w for w_ft in w_ft_list]
+
+    if _is_all_zeros(dw_list):
+        raise ExpertNotTrainedError("Expert models are not trained")
+
+    rank_of_router = target.rank_of_router
+    rank_of_expert = target.rank_of_expert
+    num_local_experts = target.num_local_experts
+    svd_list = [svd(dw, accelerator=accelerator) for dw in dw_list]
+
+    # gate
+    gate_weight = []
+    for u, s, v in svd_list:
+        gate_weight.append(v[:, :rank_of_router].T)
+    gate_weight = (
+        torch.stack(gate_weight, dim=0)
+        .reshape(num_local_experts * rank_of_router, -1)
+        .contiguous()
+    )
+
+    target.gate.load_state_dict({"weight": gate_weight})
+
+    # shared linear
+    target.shared_linear.load_state_dict(base.state_dict())
+
+    # experts
+    if rank_of_expert > 0:
+        for expert_idx, target_expert in enumerate(target.experts):
+            u, s, v = svd_list[expert_idx]
+            u = u[:, :rank_of_expert]
+            s = s[:rank_of_expert]
+            v = v[:, :rank_of_expert]
+            state_dict = {"u": u, "svh": (s * v).T}
+            if experts[expert_idx].bias is not None:
+                state_dict["bias"] = experts[expert_idx].bias.data
+            target_expert.load_state_dict(state_dict)
+    else:
+        for expert_idx, target_expert in enumerate(target.experts):
+            target_expert.load_state_dict(
+                state_dict_sub(experts[expert_idx].state_dict(), base.state_dict())
+            )
+
+    return target
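The new module builds a SMILE MoE linear layer around a shared base linear, low-rank experts taken from the SVD of each expert's weight difference, and a gate that routes on the leading right-singular directions. Below is a minimal usage sketch; the `DemoConfig` class is a hypothetical stand-in for a `SmileMoEConfig`-like object, and the import path is inferred from the `linear_from_hf_config.py` entry in the file list above:

```python
# Minimal sketch, not the library's documented API: DemoConfig only carries the
# four fields SmileLinear reads, and the module path is an assumption.
import torch
from torch import nn
from fusion_bench.models.smile_moe.linear_from_hf_config import (
    SmileLinear,
    upscale_to_smile_linear,
)

class DemoConfig:
    num_local_experts = 2     # number of experts
    num_experts_per_tok = 1   # top-k routing
    rank_of_router = 4        # singular vectors kept for the gate
    rank_of_expert = 8        # rank of each low-rank expert

base = nn.Linear(16, 16)                          # pretrained layer
experts = [nn.Linear(16, 16) for _ in range(2)]   # stand-ins for fine-tuned layers

target = SmileLinear(DemoConfig(), in_features=16, out_features=16, bias=True)
target = upscale_to_smile_linear(base, experts, target)

out = target(torch.randn(5, 16))  # shape: (5, 16)
```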
@@ -1,10 +1,12 @@
 import logging
-from typing import Dict, List, Tuple  # noqa: F401
+from typing import Dict, List, Optional, Tuple, Union  # noqa: F401
 
 import torch
 import torch.nn.functional as F
 from torch import Tensor, nn
 
+from .utils import _is_all_zeros, svd
+
 log = logging.getLogger(__name__)
 
 
@@ -12,50 +14,42 @@ class ExpertNotTrainedError(Exception):
     pass
 
 
-def _is_all_zeros(tensor: Tensor | List[Tensor]) -> bool:
-    if isinstance(tensor, Tensor):
-        return torch.allclose(tensor, torch.zeros_like(tensor))
-    else:
-        return all(_is_all_zeros(t) for t in tensor)
-
-
-def _svd(w: Tensor, full_matrices=True) -> Tuple[Tensor, Tensor, Tensor]:
-    u, s, vh = torch.linalg.svd(
-        w, full_matrices=full_matrices, driver="gesvd" if w.is_cuda else None
-    )
-    v = vh.T
-    return u, s, v
-
-
-def svd(
-    w: Tensor, full_matrices=True, accelerator=None
-) -> Tuple[Tensor, Tensor, Tensor]:
-    if accelerator is None:
-        return _svd(w, full_matrices=full_matrices)
-    original_device = w.device
-    w = w.to(accelerator)
-    u, s, v = _svd(w)
-    return u.to(original_device), s.to(original_device), v.to(original_device)
-
-
 class SmileGate(nn.Module):
+    __constants__ = ["in_features", "num_experts", "k"]
+    in_features: int
+    num_experts: int
+    k: int
+    weight: nn.Parameter
+
     def __init__(
         self,
         input_features: int,
         w_diff_list: List[Tensor],
         k: int,
-
+        svd_cache: List[
+            Tuple[Tensor, Tensor, Tensor]
+        ] = None,  # cached `svd_cache`, pass it to avoid recomputing
        upscaling_accelerator=None,
     ):
+        R"""
+        This constructs weights through SVD decomposition.
+
+        Args:
+            input_features: The dimension of input features.
+            w_diff_list: The list of weight matrices to be decomposed.
+            k: The number of singular values to keep.
+            svd_cache: The cached SVD decomposition results. If not provided, the SVD decomposition will be computed on the fly.
+            upscaling_accelerator: The accelerator to use for SVD decomposition.
+        """
         super().__init__()
         self.input_features = input_features
         self.num_experts = len(w_diff_list)
         weights = []
         for i, w_diff in enumerate(w_diff_list):
-            if
+            if svd_cache is None:
                 u, s, v = svd(w_diff, accelerator=upscaling_accelerator)
             else:
-                u, s, v =
+                u, s, v = svd_cache[i]
             u = u[:, :k]
             s = s[:k]
             v = v[:, :k]
@@ -86,8 +80,38 @@ class SmileGate(nn.Module):
 
 
 class SmileCompressedLinear(nn.Module):
-
+    """
+    This module is used to compress a linear layer using SVD decomposition.
+    """
+
+    __constants__ = ["in_features", "out_features", "k"]
+    in_features: int
+    out_features: int
+    k: int
+
+    u: nn.Parameter
+    svh: nn.Parameter
+    bias: Optional[nn.Parameter]
+
+    def __init__(
+        self,
+        model: nn.Linear,
+        k: int,
+        svd_cache: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
+    ):
+        """
+        Initialize the SmileCompressedLinear module.
+
+        Args:
+            model (nn.Linear): The linear model to compress.
+            k (int): The number of singular values to keep.
+            svd_cache (Tuple[Tensor, Tensor, Tensor]): Cached SVD results.
+        """
         super().__init__()
+        self.in_features = model.in_features
+        self.out_features = model.out_features
+        self.k = k
+
         if svd_cache is None:
             u, s, v = svd(model.weight)
         else:
@@ -106,12 +130,36 @@ class SmileCompressedLinear(nn.Module):
             self.register_parameter("bias", None)
 
     def forward(self, x):
+        """
+        Forward pass of the SmileCompressedLinear module.
+
+        Args:
+            x (Tensor): The input tensor.
+
+        Returns:
+            Tensor: The output tensor.
+        """
         x = F.linear(x, self.svh)
         x = F.linear(x, self.u, self.bias)
         return x
 
 
 class SmileMoELinear(nn.Module):
+    __constants__ = [
+        "in_features",
+        "out_features",
+        "num_experts",
+        "top_k",
+        "gate_k",
+        "k",
+    ]
+    in_features: int
+    out_features: int
+    num_experts: int
+    top_k: int
+    gate_k: int
+    k: int
+
     @torch.no_grad()
     def __init__(
         self,
@@ -124,6 +172,19 @@ class SmileMoELinear(nn.Module):
         upscaling_accelerator=None,
         routing_use_diff=True,
     ):
+        """
+        Initialize the SmileMoELinear module.
+
+        Args:
+            pretrained_model (nn.Linear): The pretrained linear model.
+            finetuned_models (List[nn.Linear]): A list of fine-tuned linear models.
+            gate_k (int): The number of singular values to keep for the gate.
+            k (int): The number of singular values to keep for the experts.
+            top_k (int): The number of top experts to select.
+            full_matrices (bool): Whether to compute the full-sized U and V matrices.
+            upscaling_accelerator (str): The device to perform the computation on.
+            routing_use_diff (bool): Whether to use weight differences for routing.
+        """
         super().__init__()
         self.num_experts = len(finetuned_models)
         self.top_k = top_k
@@ -149,7 +210,7 @@ class SmileMoELinear(nn.Module):
                 input_features=self.in_features,
                 w_diff_list=w_diff_list,
                 k=gate_k,
-
+                svd_cache=svd_cache_list,
                 upscaling_accelerator=upscaling_accelerator,
             )
         else:
@@ -157,7 +218,7 @@ class SmileMoELinear(nn.Module):
                 input_features=self.in_features,
                 w_diff_list=[m.weight for m in finetuned_models],
                 k=gate_k,
-
+                svd_cache=None,
                 upscaling_accelerator=upscaling_accelerator,
             )
 
@@ -181,6 +242,15 @@ class SmileMoELinear(nn.Module):
         self.pretrained_model = pretrained_model
 
     def forward(self, hidden_states: Tensor):
+        """
+        Forward pass of the SmileMoELinear module.
+
+        Args:
+            hidden_states (Tensor): The input tensor.
+
+        Returns:
+            Tensor: The output tensor.
+        """
         pretrained_out = self.pretrained_model(hidden_states)
 
         input_shape = hidden_states.size()
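The substantive change in this last file is the completed `svd_cache` plumbing: the SVD of each expert's weight difference can be computed once and handed to `SmileGate` instead of being recomputed inside it. A rough sketch of that path, assuming the renamed module is importable as `fusion_bench.models.smile_moe.linear_from_module` and that `SmileGate`'s forward (not shown in this hunk) still returns one routing logit per expert:

```python
# Sketch only: precompute the SVD of each weight difference once and reuse it
# via the `svd_cache` argument added in this diff; the import path is assumed.
import torch
from fusion_bench.models.smile_moe.linear_from_module import SmileGate, svd

w_diff_list = [torch.randn(16, 16) for _ in range(2)]    # per-expert weight deltas
svd_cache_list = [svd(dw) for dw in w_diff_list]          # reusable (u, s, v) tuples

gate = SmileGate(
    input_features=16,
    w_diff_list=w_diff_list,
    k=4,                       # singular vectors kept per expert
    svd_cache=svd_cache_list,  # skips recomputing the SVD inside the gate
)
routing_logits = gate(torch.randn(5, 16))  # one routing logit per expert (assumed)
```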