fusion-bench 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +1 -0
- fusion_bench/_get_started/__init__.py +3 -0
- fusion_bench/_get_started/greeting_program.py +49 -0
- fusion_bench/compat/method/base_algorithm.py +14 -0
- fusion_bench/constants/__init__.py +5 -0
- fusion_bench/constants/clip_vision.py +26 -2
- fusion_bench/constants/paths.py +4 -0
- fusion_bench/dataset/clip_dataset.py +2 -1
- fusion_bench/dataset/gpt2_glue.py +9 -9
- fusion_bench/dataset/image_corruption/__init__.py +0 -0
- fusion_bench/dataset/image_corruption/make_corruption.py +179 -0
- fusion_bench/dataset/image_dataset.py +1 -1
- fusion_bench/dataset/nyuv2.py +2 -2
- fusion_bench/method/__init__.py +16 -3
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +11 -7
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -5
- fusion_bench/method/base_algorithm.py +195 -12
- fusion_bench/method/bitdelta/__init__.py +4 -0
- fusion_bench/method/bitdelta/bitdelta.py +156 -0
- fusion_bench/method/bitdelta/bitdelta_utils/__init__.py +0 -0
- fusion_bench/method/bitdelta/bitdelta_utils/binary_gemm_kernel.py +462 -0
- fusion_bench/method/bitdelta/bitdelta_utils/data.py +35 -0
- fusion_bench/method/bitdelta/bitdelta_utils/diff.py +129 -0
- fusion_bench/method/concrete_subspace/clip_concrete_adamerging.py +0 -1
- fusion_bench/method/depth_upscaling/depth_upscaling.py +4 -9
- fusion_bench/method/doge_ta/clip_layer_wise_adamerging.py +4 -5
- fusion_bench/method/doge_ta/doge_ta.py +1 -1
- fusion_bench/method/ensemble.py +12 -12
- fusion_bench/method/expert_sparsity/utils/calibration_data.py +1 -1
- fusion_bench/method/fisher_merging/clip_fisher_merging.py +2 -2
- fusion_bench/method/fisher_merging/fisher_merging.py +6 -15
- fusion_bench/method/fisher_merging/gpt2_fisher_merging.py +3 -10
- fusion_bench/method/fw_merging/fw_hard.py +1 -1
- fusion_bench/method/fw_merging/fw_soft.py +1 -1
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +4 -5
- fusion_bench/method/linear/expo.py +2 -1
- fusion_bench/method/linear/linear_interpolation.py +6 -4
- fusion_bench/method/linear/simple_average_for_llama.py +2 -3
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +2 -2
- fusion_bench/method/mixture_of_experts/mixtral_upcycling.py +9 -26
- fusion_bench/method/model_recombination.py +2 -5
- fusion_bench/method/moe_pruner/hooks/__init__.py +1 -2
- fusion_bench/method/moe_pruner/utils/data.py +2 -1
- fusion_bench/method/moe_pruner/utils/prune.py +6 -1
- fusion_bench/method/pruning/llama_magnitude_prune.py +1 -1
- fusion_bench/method/pruning/wanda_utils/data.py +1 -2
- fusion_bench/method/pwe_moe/clip_pwe_moe.py +12 -34
- fusion_bench/method/randes/modelsoup.py +1 -3
- fusion_bench/method/regmean/clip_regmean.py +2 -2
- fusion_bench/method/regmean/gpt2_regmean.py +3 -10
- fusion_bench/method/regmean/regmean.py +2 -11
- fusion_bench/method/regmean_plusplus/__init__.py +1 -1
- fusion_bench/method/regmean_plusplus/clip_regmean_plusplus.py +24 -17
- fusion_bench/method/regmean_plusplus/regmean_plusplus.py +56 -38
- fusion_bench/method/simple_average.py +5 -9
- fusion_bench/method/slerp/slerp.py +5 -2
- fusion_bench/method/smile_upscaling/error_accumulation.py +177 -0
- fusion_bench/method/smile_upscaling/projected_energy.py +145 -0
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +39 -28
- fusion_bench/method/smile_upscaling/smile_upscaling.py +12 -5
- fusion_bench/method/tall_mask/task_arithmetic.py +3 -11
- fusion_bench/method/task_arithmetic/task_arithmetic.py +6 -10
- fusion_bench/method/ties_merging/ties_merging.py +13 -26
- fusion_bench/method/we_moe/clip_we_moe.py +5 -4
- fusion_bench/method/we_moe/we_moe.py +6 -6
- fusion_bench/method/weighted_average/llama.py +4 -16
- fusion_bench/metrics/continual_learning/__init__.py +1 -0
- fusion_bench/metrics/continual_learning/backward_transfer.py +1 -1
- fusion_bench/metrics/nyuv2/__init__.py +2 -2
- fusion_bench/metrics/nyuv2/segmentation.py +1 -1
- fusion_bench/mixins/__init__.py +10 -2
- fusion_bench/mixins/clip_classification.py +4 -3
- fusion_bench/mixins/hydra_config.py +105 -7
- fusion_bench/mixins/lightning_fabric.py +2 -0
- fusion_bench/mixins/serialization.py +265 -48
- fusion_bench/modelpool/__init__.py +2 -2
- fusion_bench/modelpool/base_pool.py +29 -9
- fusion_bench/modelpool/causal_lm/causal_lm.py +9 -0
- fusion_bench/modelpool/clip_vision/modelpool.py +1 -3
- fusion_bench/modelpool/seq_classification_lm/__init__.py +1 -1
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +1 -1
- fusion_bench/models/__init__.py +2 -1
- fusion_bench/models/expert_sparsity/mixtral/__init__.py +1 -1
- fusion_bench/models/hf_utils.py +182 -0
- fusion_bench/models/linearized/linearized_model_utils.py +4 -4
- fusion_bench/models/linearized/vision_model.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/__init__.py +1 -1
- fusion_bench/models/modeling_deepseek_v2/modeling_deepseek.py +4 -4
- fusion_bench/models/modeling_deepseek_v2/tokenization_deepseek_fast.py +0 -1
- fusion_bench/models/modeling_smile_gemma2/__init__.py +9 -0
- fusion_bench/models/modeling_smile_gemma2/configuration_smile_gemma2.py +20 -0
- fusion_bench/models/modeling_smile_gemma2/modeling_smile_gemma2.py +986 -0
- fusion_bench/models/modeling_smile_gemma2/register.py +26 -0
- fusion_bench/models/modeling_smile_llama/__init__.py +0 -0
- fusion_bench/models/modeling_smile_llama/configuration_smile_llama.py +20 -0
- fusion_bench/models/modeling_smile_llama/modeling_smile_llama.py +705 -0
- fusion_bench/models/modeling_smile_llama/register.py +8 -0
- fusion_bench/models/modeling_smile_mistral/__init__.py +5 -47
- fusion_bench/models/modeling_smile_qwen2/__init__.py +1 -1
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +6 -7
- fusion_bench/models/modeling_smile_qwen2/register.py +1 -4
- fusion_bench/models/parameter_dict.py +1 -1
- fusion_bench/models/sparse_we_moe.py +1 -53
- fusion_bench/models/utils.py +26 -0
- fusion_bench/models/we_moe.py +1 -53
- fusion_bench/models/wrappers/ensemble.py +6 -4
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -1
- fusion_bench/models/wrappers/task_wise_fusion.py +250 -72
- fusion_bench/programs/base_program.py +81 -2
- fusion_bench/programs/fabric_fusion_program.py +24 -8
- fusion_bench/scripts/cli.py +5 -5
- fusion_bench/taskpool/base_pool.py +4 -3
- fusion_bench/taskpool/clip_vision/taskpool.py +34 -18
- fusion_bench/taskpool/dummy.py +1 -1
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +1 -2
- fusion_bench/tasks/clip_classification/__init__.py +6 -4
- fusion_bench/utils/__init__.py +6 -1
- fusion_bench/utils/devices.py +14 -4
- fusion_bench/utils/instantiate_utils.py +3 -1
- fusion_bench/utils/modelscope.py +127 -8
- fusion_bench/utils/parameters.py +2 -2
- fusion_bench/utils/rich_utils.py +3 -0
- fusion_bench/utils/state_dict_arithmetic.py +25 -23
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/METADATA +24 -25
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/RECORD +165 -134
- fusion_bench_config/_get_started/clip_evaluate_single_model.yaml +21 -0
- fusion_bench_config/_get_started/clip_simple_average.yaml +23 -0
- fusion_bench_config/_get_started/clip_task_arithmetic.yaml +24 -0
- fusion_bench_config/_get_started/greeting_program.yaml +4 -0
- fusion_bench_config/fabric/loggers/csv_logger.yaml +3 -3
- fusion_bench_config/fabric/loggers/tensorboard_logger.yaml +3 -3
- fusion_bench_config/fabric_model_fusion.yaml +45 -17
- fusion_bench_config/hydra/default.yaml +6 -2
- fusion_bench_config/llama_full_finetune.yaml +1 -0
- fusion_bench_config/method/adamerging/clip.yaml +1 -1
- fusion_bench_config/method/bitdelta/bitdelta.yaml +12 -0
- fusion_bench_config/method/depth_upscaling.yaml +4 -1
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +2 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +1 -0
- fusion_bench_config/modelpool/CLIPVisionModelPool/_template.yaml +1 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +4 -9
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_finetuned.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -6
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_TA8_model_only.yaml +1 -1
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_math_and_coder.yaml +2 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-7B-math_and_coder.yaml +9 -0
- fusion_bench_config/modelpool/CausalLMPool/mistral-7b.yaml +6 -0
- fusion_bench_config/modelpool/CausalLMPool/mixtral_moe_merging.yaml +10 -0
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +4 -12
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +6 -16
- fusion_bench_config/modelpool/CausalLMPool/vicuna-7b-v1.5.yaml +8 -0
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/llama_preference700k.yaml +1 -1
- fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/single_reward_model.yaml +1 -1
- fusion_bench_config/nyuv2_config.yaml +3 -1
- fusion_bench_config/nyuv2_mtl_train.yaml +1 -0
- fusion_bench_config/path/default.yaml +28 -0
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-base-patch32_svhn_and_mnist.yaml +24 -0
- fusion_bench_config/method/adamerging.yaml +0 -23
- fusion_bench_config/modelpool/mixtral_moe_merging.yaml +0 -14
- fusion_bench_config/modelpool/mixtral_moe_upscaling.yaml +0 -6
- fusion_bench_config/taskpool/clip-vit-base-patch32_svhn_and_mnist.yaml +0 -22
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.20.dist-info → fusion_bench-0.2.21.dist-info}/top_level.txt +0 -0
- /fusion_bench_config/modelpool/{SeqenceClassificationModelPool → SequenceClassificationModelPool}/roberta-base_glue.yaml +0 -0

fusion_bench/method/smile_upscaling/projected_energy.py (new file)
@@ -0,0 +1,145 @@
+import os
+from typing import Literal
+
+import pandas as pd
+import torch
+
+from fusion_bench import BaseAlgorithm, BaseModelPool, auto_register_config
+from fusion_bench.mixins import LightningFabricMixin, SimpleProfilerMixin
+
+from tqdm import tqdm
+
+
+class ProjectedEnergyAnalysis(
+    SimpleProfilerMixin,
+    LightningFabricMixin,
+    BaseAlgorithm,
+):
+    def on_run_start(self):
+        self.device = self.fabric.device
+
+    def run(self, modelpool: BaseModelPool):
+        with self.profile("model loading"):
+            base_model = modelpool.load_pretrained_model()
+
+        results = {
+            "model_name": [],
+            "module_index": [],
+            "module_name": [],
+            "projected_energy_I": [],
+            "projected_energy_II": [],
+            "projected_energy_II_III": [],
+        }
+        for model_name in tqdm(
+            modelpool.model_names,
+            "analyzing",
+            dynamic_ncols=True,
+        ):
+            with self.profile("model loading"):
+                finetuned_model = modelpool.load_model(model_name)
+
+            module_index = 0
+            for module_name, base_module in tqdm(
+                list(base_model.named_modules()),
+                "analyzing modules",
+                dynamic_ncols=True,
+            ):
+                if isinstance(base_module, torch.nn.Linear):
+                    with self.profile("weight analysis"):
+                        _result = self.analyze_weight(
+                            base_module.weight,
+                            finetuned_model.get_submodule(module_name).weight,
+                        )
+                    results["model_name"].append(model_name)
+                    results["module_index"].append(module_index)
+                    results["module_name"].append(module_name)
+                    for key, value in _result.items():
+                        results[key].append(value)
+
+                    module_index += 1
+
+        # save results as csv
+        results = pd.DataFrame(results)
+        results.to_csv(
+            os.path.join(self.log_dir, "projected_energy_analysis.csv"), index=True
+        )
+
+        self.print_profile_summary()
+        return None
+
+    @torch.no_grad()
+    def analyze_weight(self, w: torch.Tensor, w_ft: torch.Tensor, k: int = -1):
+        w = w.to(dtype=torch.float32, device=self.device)
+        w_ft = w_ft.to(dtype=torch.float32, device=self.device)
+        w_diff = w_ft - w
+
+        # Perform analysis on the weight tensor
+        u, s, vh = torch.linalg.svd(w, full_matrices=False)
+        v = vh.T
+        if k < 0:
+            # find the position where the sum of singular values is larger than 50% of the total sum
+            cumsum = s.cumsum(0)
+            k = (cumsum < cumsum[-1] * 0.5).sum().item() + 1
+
+        # subspace I
+        w_diff_proj = self._project_subspace_low(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
+        projected_energy_I = (
+            torch.linalg.norm(w_diff_proj, ord="fro") ** 2
+            / torch.linalg.norm(w_diff, ord="fro") ** 2
+        )
+
+        # subspace II
+        w_diff_proj = self._project_subspace_high(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
+        projected_energy_II = (
+            torch.linalg.norm(w_diff_proj, ord="fro") ** 2
+            / torch.linalg.norm(w_diff, ord="fro") ** 2
+        )
+
+        ## subspace II+III
+        u, s, vh = torch.linalg.svd(w, full_matrices=True)
+        v = vh.T
+        w_diff_proj = self._project_subspace_high(u=u, s=s, v=v, k=k, w=w, w_ft=w_ft)
+        projected_energy_II_III = (
+            torch.linalg.norm(w_diff_proj, ord="fro") ** 2
+            / torch.linalg.norm(w_diff, ord="fro") ** 2
+        )
+
+        return {
+            "projected_energy_I": projected_energy_I.item(),
+            "projected_energy_II": projected_energy_II.item(),
+            "projected_energy_II_III": projected_energy_II_III.item(),
+        }
+
+    def _project_subspace_low(
+        self,
+        u: torch.Tensor,
+        s: torch.Tensor,
+        v: torch.Tensor,
+        k: int,
+        w: torch.Tensor,
+        w_ft: torch.Tensor,
+    ):
+        u = u[:, :k]
+        s = s[:k]
+        v = v[:, :k]
+
+        w_diff = w_ft - w
+        w_diff_proj = torch.linalg.multi_dot((u, u.T, w_diff, v, v.T))
+        return w_diff_proj
+
+    def _project_subspace_high(
+        self,
+        u: torch.Tensor,
+        s: torch.Tensor,
+        v: torch.Tensor,
+        k: int,
+        w: torch.Tensor,
+        w_ft: torch.Tensor,
+    ):
+        u = u[:, k:]
+        s = s[k:]
+        v = v[:, k:]
+
+        w_diff = w_ft - w
+        w_diff_proj = torch.linalg.multi_dot((u, u.T, w_diff, v, v.T))
+        return w_diff_proj
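
Note on the new analysis above: for a pretrained weight W with SVD W = U S V^T, `analyze_weight` measures how much of the fine-tuning update dW = W_ft - W lies in the span of the top-k singular directions of W, reported as the energy ratio ||U_k U_k^T dW V_k V_k^T||_F^2 / ||dW||_F^2 (subspace I), and analogously for the trailing directions (subspaces II and II+III, the latter using the full SVD). A self-contained restatement of the subspace-I computation (illustrative; the function name is ours):

    import torch

    def projected_energy_low(w: torch.Tensor, w_ft: torch.Tensor, k: int) -> float:
        # fraction of the update's Frobenius energy captured by the top-k
        # singular subspace of the pretrained weight (subspace I above)
        w_diff = w_ft - w
        u, s, vh = torch.linalg.svd(w, full_matrices=False)
        u_k, v_k = u[:, :k], vh.T[:, :k]
        proj = torch.linalg.multi_dot((u_k, u_k.T, w_diff, v_k, v_k.T))
        return (
            torch.linalg.norm(proj, ord="fro") ** 2
            / torch.linalg.norm(w_diff, ord="fro") ** 2
        ).item()

    # sanity check: an update built inside the top-8 subspace scores ~1.0
    w = torch.randn(64, 64)
    u, s, vh = torch.linalg.svd(w, full_matrices=False)
    w_ft = w + u[:, :8] @ torch.randn(8, 8) @ vh[:8, :]
    print(projected_energy_low(w, w_ft, k=8))  # ~1.0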

fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py
@@ -16,10 +16,16 @@ from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
 
 from fusion_bench import BaseAlgorithm, BaseModelPool
 from fusion_bench.compat.modelpool import to_modelpool
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
+from fusion_bench.modelpool import CausalLMPool
+from fusion_bench.models.hf_utils import (
+    generate_complete_readme,
+    save_pretrained_with_remote_code,
+)
 from fusion_bench.models.modeling_smile_qwen2 import (
     SmileQwen2Config,
     SmileQwen2ForCausalLM,
+    SmileQwen2Model,
 )
 from fusion_bench.models.modeling_smile_qwen2.modeling_smile_qwen2 import (
     SmileQwen2DecoderLayer,
@@ -34,6 +40,7 @@ from fusion_bench.utils.parameters import print_parameters
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     R"""
     SmileQwen2UpscalingAlgorithm is a model fusion algorithm designed to upscale
@@ -49,15 +56,7 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
     Merges the pretrained model with the fine-tuned models to create an upscaled model.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "device": "device",
-        "accelerator": "accelerator",
-        "model_path": "model_path",
-        "model_dtype": "model_dtype",
-        "num_experts_per_tok": "num_experts_per_tok",
-        "rank_of_router": "rank_of_router",
-        "rank_of_expert": "rank_of_expert",
-    }
+    modelpool: CausalLMPool
 
     def __init__(
         self,
@@ -68,20 +67,13 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         num_experts_per_tok,
         rank_of_router,
         rank_of_expert,
+        save_with_remote_code: bool = True,
         **kwargs,
     ):
-        self.device = device
-        self.accelerator = accelerator
-        self.model_path = model_path
-        self.model_dtype = model_dtype
-        # SmileMoE parameters, except `num_local_experts` which is set later according to the number of finetuned models
-        self.num_experts_per_tok = num_experts_per_tok
-        self.rank_of_router = rank_of_router
-        self.rank_of_expert = rank_of_expert
         super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool):
+    def run(self, modelpool) -> SmileQwen2ForCausalLM:
         """
         Executes the upscaling process.
 
@@ -129,13 +121,29 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
         if os.path.dirname(config.model_path):
             os.makedirs(os.path.dirname(config.model_path), exist_ok=True)
         log.info(f"Saving model to {config.model_path}")
-
-        pretrained_path = pretrained_model_config.get(
-            "path", pretrained_model_config["pretrained_model_name_or_path"]
-        )
-        tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
+        tokenizer = self.modelpool.load_tokenizer()
         tokenizer.save_pretrained(config.model_path)
-
+        if not self.save_with_remote_code:
+            model.save_pretrained(config.model_path)
+        else:
+            save_pretrained_with_remote_code(
+                model,
+                auto_map={
+                    "AutoConfig": SmileQwen2Config,
+                    "AutoModel": SmileQwen2Model,
+                    "AutoModelForCausalLM": SmileQwen2ForCausalLM,
+                },
+                save_directory=config.model_path,
+            )
+
+        # save readme
+        complete_readme = generate_complete_readme(
+            algorithm=self,
+            modelpool=modelpool,
+            models=[modelpool.get_model_path(m) for m in modelpool.all_model_names],
+        )
+        with open(os.path.join(config.model_path, "README.md"), "w") as f:
+            f.write(complete_readme)
 
         return model
 
@@ -158,9 +166,12 @@ class SmileQwen2UpscalingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
 
         with init_empty_weights():
             pretrained_model_config = self.modelpool.get_model_config("_pretrained_")
-            pretrained_path = pretrained_model_config.get(
-                "path", pretrained_model_config["pretrained_model_name_or_path"]
-            )
+            if isinstance(pretrained_model_config, str):
+                pretrained_path = pretrained_model_config
+            else:
+                pretrained_path = pretrained_model_config.get(
+                    "path", pretrained_model_config["pretrained_model_name_or_path"]
+                )
             base_config = AutoConfig.from_pretrained(pretrained_path)
             model_config = SmileQwen2Config(
                 num_experts_per_tok=config.num_experts_per_tok,
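
Note on `save_pretrained_with_remote_code` above: instead of requiring fusion_bench at inference time, the checkpoint is saved together with the custom SmileQwen2 modeling code, and the `auto_map` entries point `transformers` at `SmileQwen2Config`/`SmileQwen2Model`/`SmileQwen2ForCausalLM`. Assuming the standard Hugging Face remote-code mechanism, the saved model should load roughly like this (the output path is illustrative):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    path = "outputs/smile_qwen2_upscaled"  # illustrative save directory

    # trust_remote_code lets transformers import the SmileQwen2 classes
    # bundled next to the weights via the saved auto_map
    model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(path)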

fusion_bench/method/smile_upscaling/smile_upscaling.py
@@ -1,7 +1,7 @@
 import logging
 import os
 from copy import deepcopy
-from typing import Dict, List, Tuple  # noqa: F401
+from typing import Any, Dict, List, Tuple  # noqa: F401
 
 import torch
 import torch.nn.functional as F
@@ -21,6 +21,7 @@ from fusion_bench.models.smile_moe.linear_from_module import (
 )
 from fusion_bench.models.utils import get_attr, set_attr
 from fusion_bench.utils.parameters import print_parameters
+from fusion_bench.utils.devices import get_device
 
 log = logging.getLogger(__name__)
 
@@ -54,7 +55,7 @@ class SmileUpscalingAlgorithm(
         routing_use_diff: bool = True,
         average_experts: bool = False,
         model_path: str = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         """
         Initialize the SmileUpscalingAlgorithm.
@@ -91,7 +92,7 @@ class SmileUpscalingAlgorithm(
             print(f"=== Config for `{type(self).__name__}` ===")
 
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool):
+    def run(self, modelpool: BaseModelPool) -> nn.Module:
         """
         Executes the upscaling process.
 
@@ -142,7 +143,7 @@ class SmileUpscalingAlgorithm(
         pretrained_model: nn.Module,
         finetuned_models: List[nn.Module],
         in_place: bool = True,
-    ):
+    ) -> nn.Module:
         """
         Merges the pretrained model with the fine-tuned models to create an upscaled model.
 
@@ -180,7 +181,12 @@ class SmileUpscalingAlgorithm(
 
         name_list = name.split(".")
         module = get_attr(pretrained_model, name_list)
-        experts = [get_attr(m, name_list) for m in finetuned_models]
+        original_device = get_device(module)
+        module = module.to(self.device, non_blocking=True)
+        experts = [
+            get_attr(m, name_list).to(self.device, non_blocking=True)
+            for m in finetuned_models
+        ]
         try:
             moe_linear = SmileMoELinear(
                 module,
@@ -192,6 +198,7 @@ class SmileUpscalingAlgorithm(
                 full_matrices=self.full_matrices,
                 upscaling_accelerator=self.upscaling_accelerator,
             )
+            moe_linear = moe_linear.to(original_device, non_blocking=True)
         except ExpertNotTrainedError:
             print(f"skip {name} because the experts are not trained.")
             return
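
Note on the `@@ -180` and `@@ -192` hunks above: each `nn.Linear` and its expert counterparts are now shuttled to `self.device` for the SVD-heavy `SmileMoELinear` construction, and the resulting module is moved back so the rest of the model's placement is untouched. Schematically (a sketch; the helper name is ours):

    import torch
    from torch import nn

    def build_on_device(module: nn.Module, build_fn, device) -> nn.Module:
        # move the layer to the accelerator, run the heavy construction
        # there, then return the result on the layer's original device
        original_device = next(module.parameters()).device
        result = build_fn(module.to(device, non_blocking=True))
        return result.to(original_device, non_blocking=True)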

fusion_bench/method/tall_mask/task_arithmetic.py
@@ -9,7 +9,7 @@ from copy import deepcopy
 import torch
 
 from fusion_bench import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
@@ -58,16 +58,11 @@ def generate_task_masks(
     return final_mask
 
 
+@auto_register_config
 class TallMaskTaskArithmeticAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "tall_mask_lambda": "tall_mask_lambda",
-        "debug": "debug",
-        "verbose": "verbose",
-    }
-
     def __init__(
         self,
         tall_mask_lambda: float,
@@ -76,9 +71,6 @@ class TallMaskTaskArithmeticAlgorithm(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        self.tall_mask_lambda = tall_mask_lambda
-        self.debug = debug
-        self.verbose = verbose
 
     @torch.no_grad()
     def run(self, modelpool: BaseModelPool):
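
A recurring refactor in this release, visible here and in the hunks below, drops hand-written `_config_mapping` dictionaries and `self.x = x` assignments in favor of an `@auto_register_config` decorator, which evidently records `__init__` keyword arguments both on the instance and in the config mapping (the base-class order is also flipped so mixins precede `BaseAlgorithm`). A minimal sketch of what such a decorator could look like (our illustration only; the actual implementation lives in `fusion_bench/mixins` and is not part of this diff):

    import functools
    import inspect

    def auto_register_config_sketch(cls):
        # illustrative stand-in: bind __init__ arguments to attributes
        original_init = cls.__init__

        @functools.wraps(original_init)
        def __init__(self, *args, **kwargs):
            bound = inspect.signature(original_init).bind(self, *args, **kwargs)
            bound.apply_defaults()
            for name, value in bound.arguments.items():
                if name not in ("self", "args", "kwargs"):
                    setattr(self, name, value)
            original_init(self, *args, **kwargs)

        cls.__init__ = __init__
        return cls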

fusion_bench/method/task_arithmetic/task_arithmetic.py
@@ -12,7 +12,7 @@ import torch
 from torch import nn
 
 from fusion_bench.method.base_algorithm import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
@@ -74,9 +74,10 @@ def task_arithmetic_merge(
     return pretrained_model
 
 
+@auto_register_config
 class TaskArithmeticAlgorithm(
-    BaseAlgorithm,
     SimpleProfilerMixin,
+    BaseAlgorithm,
 ):
     """
     Task Arithmetic Algorithm for model fusion.
@@ -89,22 +90,17 @@ class TaskArithmeticAlgorithm(
         scaling_factor (int): The factor by which the task vectors will be scaled before merging.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "scaling_factor": "scaling_factor"
-    }
-
-    def __init__(self, scaling_factor: int):
+    def __init__(self, scaling_factor: int, **kwargs):
         """
         Initializes the TaskArithmeticAlgorithm with the given scaling factor.
 
         Args:
             scaling_factor (int): The factor by which the task vectors will be scaled before merging.
         """
-
-        super().__init__()
+        super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]):
+    def run(self, modelpool: Union[BaseModelPool, Dict[str, nn.Module]]) -> nn.Module:
         """
         Runs the Task Arithmetic Algorithm to fuse models in the given model pool.
 
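
For context on `scaling_factor`: task arithmetic forms task vectors t_i = theta_i - theta_pre and merges as theta = theta_pre + lambda * sum_i t_i, with `scaling_factor` playing the role of lambda. A minimal state-dict sketch of the rule (illustrative; independent of fusion_bench's `state_dict_arithmetic` helpers):

    import torch

    def task_arithmetic(pretrained: dict, finetuned: list[dict], lam: float) -> dict:
        merged = {}
        for key, w0 in pretrained.items():
            # sum of task vectors (w_i - w0) over the fine-tuned models
            task_vector = sum(ft[key] - w0 for ft in finetuned)
            merged[key] = w0 + lam * task_vector
        return merged

    # toy usage with two single-parameter "models"
    w0 = {"weight": torch.zeros(2)}
    fts = [{"weight": torch.tensor([1.0, 0.0])}, {"weight": torch.tensor([0.0, 1.0])}]
    print(task_arithmetic(w0, fts, lam=0.5))  # {'weight': tensor([0.5000, 0.5000])}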

fusion_bench/method/ties_merging/ties_merging.py
@@ -9,14 +9,14 @@ Overview of Ties-Merging:
 """
 
 import logging
-from typing import Dict, List, Literal, Mapping, Union  # noqa: F401
+from typing import Any, Dict, List, Literal, Mapping, Union  # noqa: F401
 
 import torch
 from torch import Tensor, nn
 
 from fusion_bench.compat.modelpool import to_modelpool
 from fusion_bench.method import BaseAlgorithm
-from fusion_bench.mixins import SimpleProfilerMixin
+from fusion_bench.mixins import SimpleProfilerMixin, auto_register_config
 from fusion_bench.modelpool import BaseModelPool
 from fusion_bench.utils.type import StateDictType
 
@@ -25,33 +25,22 @@ from .ties_merging_utils import state_dict_to_vector, ties_merging, vector_to_state_dict
 log = logging.getLogger(__name__)
 
 
-class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
-    """
-    TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
-
-    Attributes:
-        scaling_factor (float): The scaling factor to apply to the merged task vector.
-        threshold (float): The threshold for resetting values in the task vector.
-        remove_keys (List[str]): List of keys to remove from the state dictionary.
-        merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
-    """
-
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "scaling_factor": "scaling_factor",
-        "threshold": "threshold",
-        "remove_keys": "remove_keys",
-        "merge_func": "merge_func",
-    }
-
+@auto_register_config
+class TiesMergingAlgorithm(
+    SimpleProfilerMixin,
+    BaseAlgorithm,
+):
     def __init__(
         self,
         scaling_factor: float,
         threshold: float,
         remove_keys: List[str],
         merge_func: Literal["sum", "mean", "max"],
-        **kwargs,
+        **kwargs: Any,
    ):
        """
+        TiesMergingAlgorithm is a class for fusing multiple models using the TIES merging technique.
+
         Initialize the TiesMergingAlgorithm with the given parameters.
 
         Args:
@@ -61,14 +50,12 @@ class TiesMergingAlgorithm(BaseAlgorithm, SimpleProfilerMixin):
             merge_func (Literal["sum", "mean", "max"]): The merge function to use for disjoint merging.
             **kwargs: Additional keyword arguments for the base class.
         """
-        self.scaling_factor = scaling_factor
-        self.threshold = threshold
-        self.remove_keys = remove_keys
-        self.merge_func = merge_func
         super().__init__(**kwargs)
 
     @torch.no_grad()
-    def run(self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs):
+    def run(
+        self, modelpool: BaseModelPool | Dict[str, nn.Module], **kwargs: Any
+    ) -> nn.Module:
         """
         Run the TIES merging algorithm to fuse models in the model pool.
 
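
For context on `threshold` and `merge_func`: TIES merging (Yadav et al., 2023) trims each task vector to its largest-magnitude entries, elects a per-parameter sign by summed mass, and disjoint-merges only the entries that agree with the elected sign. A condensed sketch of the three steps on stacked task vectors (illustrative; the package's actual logic is in `ties_merging_utils`):

    import torch

    def ties_merge(task_vectors: torch.Tensor, k: float = 0.2) -> torch.Tensor:
        # task_vectors: (num_models, num_params)
        # 1) trim: keep only the top-k fraction of entries by magnitude
        n = task_vectors.shape[1]
        n_keep = max(1, int(k * n))
        thresh = task_vectors.abs().kthvalue(n - n_keep + 1, dim=1, keepdim=True).values
        trimmed = torch.where(task_vectors.abs() >= thresh, task_vectors, 0.0)
        # 2) elect sign: majority (by summed magnitude) per parameter
        sign = torch.sign(trimmed.sum(dim=0))
        # 3) disjoint merge: average entries that agree with the elected sign
        agree = (torch.sign(trimmed) == sign) & (trimmed != 0)
        return (trimmed * agree).sum(dim=0) / agree.sum(dim=0).clamp(min=1)

    merged = ties_merge(torch.randn(3, 10))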

fusion_bench/method/we_moe/clip_we_moe.py
@@ -2,6 +2,7 @@ import functools
 import logging
 import os
 from copy import deepcopy
+from typing import Any, Iterator
 
 import torch
 from torch import Tensor
@@ -38,7 +39,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
 
     modelpool: CLIPVisionModelPool = None
 
-    def load_checkpoint(self, model, checkpoint):
+    def load_checkpoint(self, model: Any, checkpoint: Any):
         """
         Load the checkpoint file.
 
@@ -49,7 +50,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
         state = {"model": model}
         self._fabric.load(checkpoint, state)
 
-    def save_checkpoint(self, model, checkpoint):
+    def save_checkpoint(self, model: Any, checkpoint: Any):
         """
         Save the checkpoint file.
 
@@ -102,7 +103,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
         return moe_model
 
     @functools.cache
-    def get_shuffled_test_loader_iter(self, tta_dataset: str):
+    def get_shuffled_test_loader_iter(self, tta_dataset: str) -> Iterator:
         """
         Get an iterator for the shuffled test data loader.
 
@@ -131,7 +132,7 @@ class CLIPWeightEnsemblingMoEAlgorithm(
         """
         self.setup_zero_shot_classification_head()
 
-    def compute_logits(self, module, batch, task) -> Tensor:
+    def compute_logits(self, module: Any, batch: Any, task: Any) -> Tensor:
         """
         Compute the logits for the given batch and task.
 

fusion_bench/method/we_moe/we_moe.py
@@ -1,6 +1,6 @@
 import logging
 from abc import abstractmethod
-from typing import cast  # noqa: F401
+from typing import Any, cast  # noqa: F401
 
 import lightning as L
 import lightning.fabric.wrappers
@@ -70,7 +70,7 @@ class WeightEnsemblingMoEAlgorithm(
             assert "No CUDA device available."
 
     @abstractmethod
-    def load_checkpoint(self, model, checkpoint):
+    def load_checkpoint(self, model: Any, checkpoint: Any):
         """
         Load the checkpoint file.
 
@@ -81,7 +81,7 @@ class WeightEnsemblingMoEAlgorithm(
         pass
 
     @abstractmethod
-    def save_checkpoint(self, model, checkpoint):
+    def save_checkpoint(self, model: Any, checkpoint: Any):
         """
         Save the checkpoint file.
 
@@ -121,7 +121,7 @@ class WeightEnsemblingMoEAlgorithm(
         pass
 
     @abstractmethod
-    def compute_logits(self, module, batch, task) -> Tensor:
+    def compute_logits(self, module: Any, batch: Any, task: Any) -> Tensor:
         """
         Compute the logits for a given batch and task.
 
@@ -135,7 +135,7 @@ class WeightEnsemblingMoEAlgorithm(
         """
         pass
 
-    def test_time_adaptation(self, module: WeightEnsemblingMoE):
+    def test_time_adaptation(self, module: WeightEnsemblingMoE) -> WeightEnsemblingMoE:
         """
         Perform test-time adaptation for the given module.
 
@@ -208,7 +208,7 @@ class WeightEnsemblingMoEAlgorithm(
 
         return module
 
-    def run(self, modelpool: ModelPool):
+    def run(self, modelpool: ModelPool) -> WeightEnsemblingMoE:
         """
         Run the WeightEnsemblingMoEAlgorithm to fuse models using Weight Ensembling Mixture of Experts.
 

fusion_bench/method/weighted_average/llama.py
@@ -3,6 +3,7 @@ from typing import List, Mapping, Union  # noqa: F401
 
 import numpy as np
 import torch
+from transformers import PreTrainedModel
 from typing_extensions import override
 
 from fusion_bench.method import BaseAlgorithm
@@ -10,24 +11,17 @@ from fusion_bench.modelpool import CausalLMPool
 from fusion_bench.utils import timeit_context
 from fusion_bench.utils.state_dict_arithmetic import state_dict_add, state_dict_mul
 from fusion_bench.utils.type import StateDictType
+from fusion_bench.mixins import auto_register_config
 
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class WeightedAverageForLLama(BaseAlgorithm):
     """
     A class to perform weighted averaging of LlaMa/Mistral models.
     """
 
-    _config_mapping = BaseAlgorithm._config_mapping | {
-        "normalize": "normalize",
-        "weights": "weights",
-        "backbone_only": "backbone_only",
-        "merged_model_save_path": "merged_model_save_path",
-        "save_tokenizer": "save_tokenizer",
-        "push_to_hub": "push_to_hub",
-    }
-
     def __init__(
         self,
         normalize: bool,
@@ -49,17 +43,11 @@ class WeightedAverageForLLama(BaseAlgorithm):
             save_tokenizer (bool): Whether to save the tokenizer.
             push_to_hub (bool): Whether to push the model to the hub.
         """
-        self.normalize = normalize
-        self.weights = weights
-        self.backbone_only = backbone_only
-        self.merged_model_save_path = merged_model_save_path
-        self.save_tokenizer = save_tokenizer
-        self.push_to_hub = push_to_hub
         super().__init__(**kwargs)
 
     @override
     @torch.no_grad()
-    def run(self, modelpool: CausalLMPool):
+    def run(self, modelpool: CausalLMPool) -> PreTrainedModel:
         """
         Executes the weighted averaging of models in the provided model pool.
 
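
For context on `normalize`: with normalization enabled the weights are rescaled to sum to one, so the merged parameters are theta = sum_i w_i * theta_i / sum_j w_j. A minimal sketch (illustrative):

    import torch

    def weighted_average(state_dicts: list[dict], weights: list[float], normalize: bool = True) -> dict:
        if normalize:
            total = sum(weights)
            weights = [w / total for w in weights]
        merged = {k: torch.zeros_like(v) for k, v in state_dicts[0].items()}
        for sd, w in zip(state_dicts, weights):
            for k in merged:
                merged[k] += w * sd[k]
        return merged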

fusion_bench/metrics/continual_learning/__init__.py (new file)
@@ -0,0 +1 @@
+from .backward_transfer import compute_backward_transfer
fusion_bench/metrics/continual_learning/backward_transfer.py
@@ -10,7 +10,7 @@ def compute_backward_transfer(
     Compute the backward transfer (BWT) of a model on a set of tasks.
 
     Equation:
-        BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{
+        $BWT = \frac{1}{n} \sum_{k=1}^{n} (acc_{T,i}[k] - acc_{i,i}[k])$
 
     Returns:
         float: The backward transfer of the model.
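
On the corrected equation above: backward transfer compares the final accuracy on each task against the accuracy measured right after that task was learned, so negative BWT indicates forgetting. (Normalization conventions vary; this docstring divides by n, while the GEM formulation of Lopez-Paz & Ranzato divides by T - 1.) A worked toy example following the docstring's convention:

    # acc[i][k]: accuracy on task k measured after training on task i
    acc = [
        [0.90, 0.00],  # after task 0
        [0.85, 0.92],  # after task 1 (final row)
    ]
    n = len(acc)
    # final accuracy on each task minus accuracy just after learning it
    bwt = sum(acc[n - 1][k] - acc[k][k] for k in range(n)) / n
    print(bwt)  # -0.025: on average the model lost accuracy on earlier tasks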
fusion_bench/metrics/nyuv2/__init__.py
@@ -1,10 +1,10 @@
 from .depth import DepthMetric
 from .noise import NoiseMetric
 from .normal import NormalMetric
-from .segmentation import
+from .segmentation import SegmentationMetric
 
 metric_classes = {
-    "segmentation":
+    "segmentation": SegmentationMetric,
     "depth": DepthMetric,
     "normal": NormalMetric,
     "noise": NoiseMetric,