fusion-bench 0.2.20__py3-none-any.whl → 0.2.22__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- fusion_bench/__init__.py +22 -2
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +6 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/constants/runtime.py +57 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +24 -5
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +5 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/classification/clip_finetune.py +1 -1
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -6
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +17 -13
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +12 -16
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/causal_lm_upscaling.py +371 -0
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +144 -0
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +5 -1
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +71 -51
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/__init__.py +1 -0
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/entropy_loss.py +25 -0
- fusion_bench/method/we_moe/flan_t5_we_moe.py +331 -0
- fusion_bench/method/we_moe/utils.py +15 -0
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +15 -45
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +275 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +41 -33
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +7 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +160 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/model_card_templates/default.md +46 -0
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +7 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +698 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +7 -12
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +46 -61
- fusion_bench/scripts/cli.py +38 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -22
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +7 -1
- fusion_bench/utils/cache_utils.py +101 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/fabric.py +2 -2
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/lazy_imports.py +23 -0
- fusion_bench/utils/lazy_state_dict.py +38 -3
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/path.py +56 -0
- fusion_bench/utils/pylogger.py +1 -1
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/METADATA +24 -47
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/RECORD +184 -145
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/fisher_merging/clip_fisher_merging.yaml +0 -1
- fusion_bench_config/method/linear/simple_average_for_llama.yaml +3 -2
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +21 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +2 -1
- fusion_bench_config/method/wemoe/flan_t5_weight_ensembling_moe.yaml +20 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +3 -3
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.22.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0

fusion_bench/method/smile_upscaling/causal_lm_upscaling.py (new file, @@ -0,0 +1,371 @@):

```python
import logging
import os
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, Union

import torch
from accelerate import init_empty_weights
from tqdm.auto import tqdm
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaForCausalLM,
    MistralForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
    Qwen2ForCausalLM,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

from fusion_bench import BaseAlgorithm, BaseModelPool
from fusion_bench.compat.modelpool import to_modelpool
from fusion_bench.constants import RuntimeConstants
from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
from fusion_bench.modelpool import CausalLMPool
from fusion_bench.models.hf_utils import (
    create_default_model_card,
    save_pretrained_with_remote_code,
)
from fusion_bench.models.modeling_smile_llama import (
    SmileLlamaConfig,
    SmileLlamaForCausalLM,
    SmileLlamaModel,
)
from fusion_bench.models.modeling_smile_llama.modeling_smile_llama import (
    SmileLlamaDecoderLayer,
)
from fusion_bench.models.modeling_smile_mistral import (
    SmileMistralConfig,
    SmileMistralForCausalLM,
    SmileMistralModel,
)
from fusion_bench.models.modeling_smile_mistral.modeling_smile_mistral import (
    SmileMistralDecoderLayer,
)

# Import all SMILE configurations and models
from fusion_bench.models.modeling_smile_qwen2 import (
    SmileQwen2Config,
    SmileQwen2ForCausalLM,
    SmileQwen2Model,
)
from fusion_bench.models.modeling_smile_qwen2.modeling_smile_qwen2 import (
    SmileQwen2DecoderLayer,
)
from fusion_bench.models.smile_moe.linear_from_hf_config import (
    ExpertNotTrainedError,
    upscale_to_smile_linear,
)
from fusion_bench.utils.dtype import parse_dtype
from fusion_bench.utils.parameters import print_parameters

log = logging.getLogger(__name__)

# Model type mappings
MODEL_TYPE_MAPPINGS = {
    "qwen2": {
        "base_model_cls": Qwen2ForCausalLM,
        "base_decoder_layer_cls": Qwen2DecoderLayer,
        "smile_config_cls": SmileQwen2Config,
        "smile_model_cls": SmileQwen2ForCausalLM,
        "smile_base_model_cls": SmileQwen2Model,
        "smile_decoder_layer_cls": SmileQwen2DecoderLayer,
        "description": "Qwen2",
    },
    "llama": {
        "base_model_cls": LlamaForCausalLM,
        "base_decoder_layer_cls": LlamaDecoderLayer,
        "smile_config_cls": SmileLlamaConfig,
        "smile_model_cls": SmileLlamaForCausalLM,
        "smile_base_model_cls": SmileLlamaModel,
        "smile_decoder_layer_cls": SmileLlamaDecoderLayer,
        "description": "Llama",
    },
    "mistral": {
        "base_model_cls": MistralForCausalLM,
        "base_decoder_layer_cls": MistralDecoderLayer,
        "smile_config_cls": SmileMistralConfig,
        "smile_model_cls": SmileMistralForCausalLM,
        "smile_base_model_cls": SmileMistralModel,
        "smile_decoder_layer_cls": SmileMistralDecoderLayer,
        "description": "Mistral",
    },
}


def detect_model_type(
    model_or_config: Union[PreTrainedModel, PretrainedConfig, str],
) -> str:
    """
    Detect the model type from a model, config, or model name/path.

    Args:
        model_or_config: Model, config, or model name/path to detect type from

    Returns:
        str: The detected model type ("qwen2", "llama", "mistral")

    Raises:
        ValueError: If model type cannot be detected or is not supported
    """
    if isinstance(model_or_config, str):
        # Load config from path/name
        config = AutoConfig.from_pretrained(model_or_config)
    elif isinstance(model_or_config, PreTrainedModel):
        config = model_or_config.config
    elif isinstance(model_or_config, PretrainedConfig):
        config = model_or_config
    else:
        raise ValueError(
            f"Unsupported type for model type detection: {type(model_or_config)}"
        )

    model_type = getattr(config, "model_type", "").lower()

    # Handle various model type variations
    if model_type in MODEL_TYPE_MAPPINGS:
        return model_type
    else:
        raise ValueError(
            f"Unsupported model type: {model_type}. Supported types: {list(MODEL_TYPE_MAPPINGS.keys())}"
        )


@auto_register_config
class SmileCausalLMUpscalingAlgorithm(
    SimpleProfilerMixin,
    BaseAlgorithm,
):
    R"""
    SmileCausalLMUpscalingAlgorithm is a generic model fusion algorithm designed to upscale
    a pretrained CausalLM model using a set of fine-tuned expert models. The algorithm
    supports Qwen2, Llama, and Mistral model architectures and leverages Singular Value
    Decomposition (SVD) to merge the weights of the pretrained model and the expert models
    into a new upscaled model.

    The algorithm automatically detects the model type and uses the appropriate SMILE
    configuration and model classes.

    Methods:
        run(modelpool: BaseModelPool) -> Union[SmileQwen2ForCausalLM, SmileLlamaForCausalLM, SmileMistralForCausalLM]:
            Executes the upscaling process and returns the upscaled model.

        merge(pretrained_model: PreTrainedModel, finetuned_models: List[PreTrainedModel]) -> PreTrainedModel:
            Merges the pretrained model with the fine-tuned models to create an upscaled model.
    """

    modelpool: CausalLMPool

    def __init__(
        self,
        device,
        accelerator,
        model_save_path,
        model_dtype,
        num_experts_per_tok,
        rank_of_router,
        rank_of_expert,
        save_with_remote_code: bool = True,
        model_type: str = None,  # Optional: explicitly specify model type
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.model_mappings = None  # Will be set during run()

        if not torch.cuda.is_available():
            if "cuda" in self.device:
                self.device = "cpu"
            if "cuda" in self.accelerator:
                self.accelerator = "cpu"

    @torch.no_grad()
    def run(self, modelpool) -> PreTrainedModel:
        """
        Executes the upscaling process.

        Args:
            modelpool (ModelPool): The pool of models to be used for upscaling.

        Returns:
            PreTrainedModel: The upscaled model (specific type depends on detected model architecture).
        """
        self.modelpool = modelpool = to_modelpool(modelpool)
        config = self.config

        # Auto-detect model type if not specified
        if self.model_type is None:
            self.model_type = detect_model_type(
                modelpool.get_model_path("_pretrained_")
            )
            log.info(f"Auto-detected model type: {self.model_type}")

        # Get the appropriate model mappings
        if self.model_type not in MODEL_TYPE_MAPPINGS:
            raise ValueError(
                f"Unsupported model type: {self.model_type}. Supported: {list(MODEL_TYPE_MAPPINGS.keys())}"
            )

        self.model_mappings = MODEL_TYPE_MAPPINGS[self.model_type]
        log.info(f"Using {self.model_mappings['description']} model architecture")

        with self.profile("load pretrained model"):
            pretrained_model = modelpool.load_pretrained_model()

        with self.profile("load fine-tuned model"):
            finetuned_models = [
                m for m in tqdm(modelpool.models(), total=len(modelpool.model_names))
            ]

        if self.device == "cuda" and torch.cuda.is_available():
            pretrained_model = pretrained_model.cuda()
            print("parameter count of pretrained model:")
            print_parameters(pretrained_model)
            finetuned_models = [m.cuda() for m in finetuned_models]

        with self.profile("merge model"):
            model = self.merge(pretrained_model, finetuned_models)

        self.print_profile_summary()
        print("parameter count of upscaled MoE model:")
        print_parameters(model)
        print(model)

        if self.model_dtype is not None:
            model.to(dtype=parse_dtype(self.model_dtype))

        if self.model_save_path is not None:
            if os.path.dirname(self.model_save_path):
                os.makedirs(os.path.dirname(self.model_save_path), exist_ok=True)
            log.info(f"Saving model to {self.model_save_path}")
            tokenizer = self.modelpool.load_tokenizer()
            tokenizer.save_pretrained(self.model_save_path)
            if not self.save_with_remote_code:
                model.save_pretrained(self.model_save_path)
            else:
                # Use the appropriate auto_map for the detected model type
                auto_map = {
                    "AutoConfig": self.model_mappings["smile_config_cls"],
                    "AutoModel": self.model_mappings["smile_base_model_cls"],
                    "AutoModelForCausalLM": self.model_mappings["smile_model_cls"],
                }
                save_pretrained_with_remote_code(
                    model,
                    auto_map=auto_map,
                    save_directory=self.model_save_path,
                )

            # save readme
            model_card_str = create_default_model_card(
                models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
                description=f"Merged {self.model_mappings['description']} model using SMILE Upscaling",
                algorithm_config=self.config,
                modelpool_config=modelpool.config,
            )
            with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
                f.write(model_card_str)

        return model

    def merge(
        self,
        pretrained_model: PreTrainedModel,
        finetuned_models: List[PreTrainedModel],
    ) -> PreTrainedModel:
        """
        Merges the pretrained model with the fine-tuned models to create an upscaled model.

        Args:
            pretrained_model (PreTrainedModel): The pretrained model.
            finetuned_models (List[PreTrainedModel]): A list of fine-tuned models.

        Returns:
            PreTrainedModel: The upscaled model (specific type depends on model architecture).
        """
        with init_empty_weights():
            pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
            if isinstance(pretrained_model_config, str):
                pretrained_path = pretrained_model_config
            else:
                pretrained_path = pretrained_model_config.get(
                    "path", pretrained_model_config["pretrained_model_name_or_path"]
                )
            base_config = AutoConfig.from_pretrained(pretrained_path)

            # Create the appropriate SMILE config for the detected model type
            SmileConfigClass = self.model_mappings["smile_config_cls"]
            model_config = SmileConfigClass(
                num_experts_per_tok=self.num_experts_per_tok,
                rank_of_router=self.rank_of_router,
                rank_of_expert=self.rank_of_expert,
                num_local_experts=len(finetuned_models),
                **base_config.to_dict(),
            )

            # Create the appropriate SMILE model for the detected model type
            SmileModelClass = self.model_mappings["smile_model_cls"]
            model = SmileModelClass(model_config)

        model.to(dtype=pretrained_model.dtype).to_empty(device="cpu")

        # copy pretrained model weights
        state_dict = model.state_dict()
        pretrained_state_dict = pretrained_model.state_dict()
        for key in list(pretrained_state_dict.keys()):
            if key not in state_dict:
                pretrained_state_dict.pop(key)
        model.load_state_dict(pretrained_state_dict, strict=False)

        # upscale model
        BaseDecoderLayerClass = self.model_mappings["base_decoder_layer_cls"]
        SmileDecoderLayerClass = self.model_mappings["smile_decoder_layer_cls"]

        for layer_idx in tqdm(
            range(len(pretrained_model.model.layers)),
            "Upscaling Modules (layer)",
            dynamic_ncols=True,
        ):
            if RuntimeConstants.debug and layer_idx > 0:
                log.info(
                    "Debug mode enabled: processing only the first layer, skipping remaining layers"
                )
                break

            pretrained_layer = pretrained_model.model.layers[layer_idx]
            finetuned_layers = [m.model.layers[layer_idx] for m in finetuned_models]

            target_layer = model.model.layers[layer_idx]

            for n in ["q_proj", "k_proj", "v_proj", "o_proj"]:
                try:
                    upscale_to_smile_linear(
                        base=getattr(pretrained_layer.self_attn, n),
                        experts=[getattr(m.self_attn, n) for m in finetuned_layers],
                        target=getattr(target_layer.self_attn, n),
                        accelerator=self.accelerator,
                    )
                except ExpertNotTrainedError:
                    setattr(
                        target_layer.self_attn,
                        n,
                        getattr(pretrained_layer.self_attn, n),
                    )

            for n in ["gate_proj", "up_proj", "down_proj"]:
                try:
                    upscale_to_smile_linear(
                        base=getattr(pretrained_layer.mlp, n),
                        experts=[getattr(m.mlp, n) for m in finetuned_layers],
                        target=getattr(target_layer.mlp, n),
                        accelerator=self.accelerator,
                    )
                except ExpertNotTrainedError:
                    setattr(
                        target_layer.mlp,
                        n,
                        getattr(pretrained_layer.mlp, n),
                    )

        return model
```
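
A minimal usage sketch (not part of the released diff; the argument values and the `modelpool` object are hypothetical, while the constructor parameters and the `run(modelpool)` entry point come from the code above):

```python
# Hypothetical example: upscale a pretrained Qwen2/Llama/Mistral model with its
# fine-tuned experts. `modelpool` is assumed to be a CausalLMPool whose entries
# include "_pretrained_" plus one entry per expert model.
from fusion_bench.method.smile_upscaling.causal_lm_upscaling import (
    SmileCausalLMUpscalingAlgorithm,
)

algorithm = SmileCausalLMUpscalingAlgorithm(
    device="cuda",
    accelerator="cuda",
    model_save_path="outputs/smile_upscaled_model",  # hypothetical path
    model_dtype="bfloat16",
    num_experts_per_tok=1,
    rank_of_router=8,
    rank_of_expert=512,
    # model_type is optional; when omitted it is auto-detected from the
    # pretrained model's config via detect_model_type().
)
merged_model = algorithm.run(modelpool)
```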

fusion_bench/method/smile_upscaling/error_accumulation.py (new file, @@ -0,0 +1,177 @@):

```python
import os
from typing import Literal, cast

import pandas as pd
import torch
from omegaconf import DictConfig
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import CLIPVisionModel

from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
from fusion_bench.dataset import CLIPDataset
from fusion_bench.method import SmileUpscalingAlgorithm
from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
from fusion_bench.modelpool import CLIPVisionModelPool
from fusion_bench.taskpool.clip_vision.taskpool import LayerWiseFeatureSaver
from fusion_bench.utils.devices import clear_cuda_cache


@auto_register_config
class LowRankApproximation(BaseAlgorithm):
    def __init__(self, rank: int, device: str = "cuda", **kwargs):
        """Low-rank approximation of fine-tuned updates."""
        super().__init__(**kwargs)

    def run(self, modelpool: BaseModelPool):
        # Implement low-rank approximation logic here
        base_model = modelpool.load_pretrained_model()

        models = {}
        for model_name in tqdm(modelpool.model_names, "processing models"):
            task_model = modelpool.load_model(model_name)
            for module_name, module in task_model.named_modules():
                if isinstance(module, nn.Linear):
                    w = cast(
                        nn.Linear, base_model.get_submodule(module_name)
                    ).weight.to(dtype=torch.float32, device=self.device, copy=True)
                    w_ft = module.weight.to(
                        dtype=torch.float32, device=self.device, copy=True
                    )

                    # Compute low-rank approximation
                    w_diff = w_ft - w
                    u, s, vh = torch.linalg.svd(w_diff)
                    v = vh.T

                    u = u[:, : self.rank]
                    s = s[: self.rank]
                    v = v[:, : self.rank]

                    low_rank_w_diff = torch.linalg.multi_dot((u, torch.diag(s), v.T))
                    low_rank_w = w + low_rank_w_diff

                    module.weight.data = low_rank_w.to(
                        dtype=module.weight.dtype,
                        device=module.weight.device,
                    )

            models[model_name] = task_model
        return models


@auto_register_config
class ErrorAccumulationAnalysisForCLIP(
    LightningFabricMixin,
    BaseAlgorithm,
):
    def __init__(
        self,
        gate_k: int,
        k: int,
        seed: int = 42,
        top_k: int = 1,
        dataset_kwargs: DictConfig = None,
        max_samples: int = 1024,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if dataset_kwargs is None:
            self.dataset_kwargs = DictConfig(
                {
                    "batch_size": 32,
                    "num_workers": 4,
                }
            )

    def run(self, modelpool: CLIPVisionModelPool):
        assert self.fabric.world_size == 1, "Distributed inference is not supported."
        # get the smile model
        smile_algorithm = SmileUpscalingAlgorithm(
            gate_k=self.gate_k, k=self.k, top_k=self.top_k, device=self.fabric.device
        )
        smile_model = smile_algorithm.run(modelpool)
        # get low-rank models
        low_rank_models = LowRankApproximation(rank=self.k).run(modelpool)

        results = {
            "model_name": [],
            "method": [],
            "layer_index": [],
            "approximation_error": [],
        }

        for model_name in modelpool.model_names:
            dataset = modelpool.load_test_dataset(model_name)
            processor = modelpool.load_processor()
            dataset = CLIPDataset(dataset, processor)
            dataloader = DataLoader(dataset, shuffle=True, **self.dataset_kwargs)
            dataloader = self.fabric.setup_dataloaders(dataloader)

            # finetuned_model
            finetuned_model = modelpool.load_model(model_name)
            finetuned_model = self.to_device(finetuned_model)
            self.collect_hidden_states(
                finetuned_model,
                dataloader=dataloader,
                model_name=f"{model_name}/finetuned",
            )
            del finetuned_model
            clear_cuda_cache()

            # smile model
            smile_model = self.to_device(smile_model)
            self.collect_hidden_states(
                smile_model, dataloader=dataloader, model_name=f"{model_name}/smile"
            )
            smile_model.cpu()
            clear_cuda_cache()

            # low-rank models
            model = low_rank_models.pop(model_name)
            model = self.to_device(model)
            self.collect_hidden_states(
                model, dataloader=dataloader, model_name=f"{model_name}/low-rank"
            )
            del model
            clear_cuda_cache()

            del dataloader
            clear_cuda_cache()

    @torch.no_grad()
    def collect_hidden_states(
        self, model: CLIPVisionModel, dataloader, model_name: str
    ):
        self.fabric.seed_everything(
            self.seed, workers=True
        )  # make sure to get same data samples
        # register hooks
        hooks = {}
        hook_handles = {}
        for i, layer in enumerate(model.vision_model.encoder.layers):
            hooks[i] = LayerWiseFeatureSaver(
                save_path=os.path.join(self.log_dir, model_name, f"layer_{i}.pth"),
                first_token_only=True,
            )
            hook_handles[i] = layer.register_forward_hook(hooks[i])

        # forward pass
        num_total_samples = 0
        for images, _ in tqdm(dataloader, desc=f"Collecting features for {model_name}"):
            batch_size = images.size(0)
            model(images)
            num_total_samples += batch_size
            if num_total_samples >= self.max_samples:
                break

        # save features
        for i, hook in hooks.items():
            hook.save_features()

        # remove hooks
        for i, hook_handle in hook_handles.items():
            hook_handle.remove()

        return hooks
```
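
A rough usage sketch (hypothetical parameter values and `modelpool`; only the constructor parameters, the `run()` entry point, and the output layout come from the code above):

```python
# Hypothetical example: collect per-layer hidden states of the fine-tuned,
# SMILE-upscaled, and rank-k approximated models for later error-accumulation
# analysis. `modelpool` is assumed to be a configured CLIPVisionModelPool.
from fusion_bench.method.smile_upscaling.error_accumulation import (
    ErrorAccumulationAnalysisForCLIP,
)

analysis = ErrorAccumulationAnalysisForCLIP(gate_k=16, k=32)
analysis.run(modelpool)
# Features are saved under the run's log directory, one file per encoder layer:
#   <log_dir>/<model_name>/finetuned/layer_<i>.pth
#   <log_dir>/<model_name>/smile/layer_<i>.pth
#   <log_dir>/<model_name>/low-rank/layer_<i>.pth
```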

fusion_bench/method/smile_upscaling/projected_energy.py (new file, @@ -0,0 +1,144 @@):

```python
import os
from typing import Literal

import pandas as pd
import torch
from tqdm import tqdm

from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin


class ProjectedEnergyAnalysis(
    SimpleProfilerMixin,
    LightningFabricMixin,
    BaseAlgorithm,
):
    def on_run_start(self):
        self.device = self.fabric.device

    def run(self, modelpool: BaseModelPool):
        with self.profile("model loading"):
            base_model = modelpool.load_pretrained_model()

        results = {
            "model_name": [],
            "module_index": [],
            "module_name": [],
            "projected_energy_I": [],
            "projected_energy_II": [],
            "projected_energy_II_III": [],
        }
        for model_name in tqdm(
            modelpool.model_names,
            "analyzing",
            dynamic_ncols=True,
        ):
            with self.profile("model loading"):
                finetuned_model = modelpool.load_model(model_name)

            module_index = 0
            for module_name, base_module in tqdm(
                list(base_model.named_modules()),
                "analyzing modules",
                dynamic_ncols=True,
            ):
                if isinstance(base_module, torch.nn.Linear):
                    with self.profile("weight analysis"):
                        _result = self.analyze_weight(
                            base_module.weight,
                            finetuned_model.get_submodule(module_name).weight,
                        )
                    results["model_name"].append(model_name)
                    results["module_index"].append(module_index)
                    results["module_name"].append(module_name)
                    for key, value in _result.items():
                        results[key].append(value)

                    module_index += 1

        # save results as csv
        results = pd.DataFrame(results)
        results.to_csv(
            os.path.join(self.log_dir, "projected_energy_analysis.csv"), index=True
        )

        self.print_profile_summary()
        return None

    @torch.no_grad()
    def analyze_weight(self, w: torch.Tensor, w_ft: torch.Tensor, k: int = -1):
        w = w.to(dtype=torch.float32, device=self.device)
        w_ft = w_ft.to(dtype=torch.float32, device=self.device)
        w_diff = w_ft - w

        # Perform analysis on the weight tensor
        u, s, vh = torch.linalg.svd(w, full_matrices=False)
        v = vh.T
        if k < 0:
            # find the position where the sum of singular values is larger than 50% of the total sum
            cumsum = s.cumsum(0)
            k = (cumsum < cumsum[-1] * 0.5).sum().item() + 1

        # subspace I
        w_diff_proj = self._project_subspace_low(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
        projected_energy_I = (
            torch.linalg.norm(w_diff_proj, ord="fro") ** 2
            / torch.linalg.norm(w_diff, ord="fro") ** 2
        )

        # subspace II
        w_diff_proj = self._project_subspace_high(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
        projected_energy_II = (
            torch.linalg.norm(w_diff_proj, ord="fro") ** 2
            / torch.linalg.norm(w_diff, ord="fro") ** 2
        )

        ## subspace II+III
        u, s, vh = torch.linalg.svd(w, full_matrices=True)
        v = vh.T
        w_diff_proj = self._project_subspace_high(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
        projected_energy_II_III = (
            torch.linalg.norm(w_diff_proj, ord="fro") ** 2
            / torch.linalg.norm(w_diff, ord="fro") ** 2
        )

        return {
            "projected_energy_I": projected_energy_I.item(),
            "projected_energy_II": projected_energy_II.item(),
            "projected_energy_II_III": projected_energy_II_III.item(),
        }

    def _project_subspace_low(
        self,
        u: torch.Tensor,
        s: torch.Tensor,
        v: torch.Tensor,
        k: int,
        w: torch.Tensor,
        w_ft: torch.Tensor,
    ):
        u = u[:, :k]
        s = s[:k]
        v = v[:, :k]

        w_diff = w_ft - w
        w_diff_proj = torch.linalg.multi_dot((u, u.T, w_diff, v, v.T))
        return w_diff_proj

    def _project_subspace_high(
        self,
        u: torch.Tensor,
        s: torch.Tensor,
        v: torch.Tensor,
        k: int,
        w: torch.Tensor,
        w_ft: torch.Tensor,
    ):
        u = u[:, k:]
        s = s[k:]
        v = v[:, k:]

        w_diff = w_ft - w
        w_diff_proj = torch.linalg.multi_dot((u, u.T, w_diff, v, v.T))
        return w_diff_proj
```
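
For reference, the quantities that `analyze_weight` reports can be written compactly. This is a reading of the code above, not text taken from the release: let $W$ be the pretrained weight, $W_{ft}$ the fine-tuned weight, $\Delta = W_{ft} - W$, and $W = U \Sigma V^\top$ its thin SVD, with $k$ chosen as the smallest index whose leading singular values exceed half of their total sum.

```latex
% Energy of the fine-tuning update captured by subspaces of the base weight W.
% U_k, V_k: top-k singular vectors of W; U_{>k}, V_{>k}: the remaining columns
% of the thin SVD; \tilde{U}_{>k}, \tilde{V}_{>k}: the remaining columns of the
% full SVD, which additionally span the null-space directions of W.
E_{\mathrm{I}} = \frac{\lVert U_k U_k^\top \, \Delta \, V_k V_k^\top \rVert_F^2}{\lVert \Delta \rVert_F^2},
\qquad
E_{\mathrm{II}} = \frac{\lVert U_{>k} U_{>k}^\top \, \Delta \, V_{>k} V_{>k}^\top \rVert_F^2}{\lVert \Delta \rVert_F^2},
\qquad
E_{\mathrm{II+III}} = \frac{\lVert \tilde{U}_{>k} \tilde{U}_{>k}^\top \, \Delta \, \tilde{V}_{>k} \tilde{V}_{>k}^\top \rVert_F^2}{\lVert \Delta \rVert_F^2}
```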