fusion-bench 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +2 -0
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/fer2013.py +1 -0
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +10 -0
- fusion_bench/method/ada_svd/clip_vision.py +4 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +16 -7
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_mistral_upscaling.py +46 -145
- fusion_bench/method/smile_upscaling/smile_qwen2_upscaling.py +229 -0
- fusion_bench/method/smile_upscaling/smile_upscaling.py +19 -346
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/modeling_smile_mistral/modeling_smile_mistral.py +2 -203
- fusion_bench/models/modeling_smile_qwen2/__init__.py +8 -0
- fusion_bench/models/modeling_smile_qwen2/configuration_smile_qwen2.py +21 -0
- fusion_bench/models/modeling_smile_qwen2/modeling_smile_qwen2.py +922 -0
- fusion_bench/models/modeling_smile_qwen2/register.py +11 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/rankone_moe.py +2 -88
- fusion_bench/models/smile_moe/linear_from_hf_config.py +373 -0
- fusion_bench/models/smile_moe/{linear.py → linear_from_module.py} +103 -33
- fusion_bench/models/smile_moe/utils/__init__.py +24 -0
- fusion_bench/models/smile_moe/utils/svd_utils.py +46 -0
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +7 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/lm_eval_harness/__init__.py +3 -0
- fusion_bench/taskpool/lm_eval_harness/taskpool.py +87 -0
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/METADATA +22 -2
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/RECORD +209 -157
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/doge_ta/doge_ta.yaml +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -2
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +13 -0
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/qwen2_math_1.5B_and_R1.yaml +17 -0
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +1 -3
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/LMEvalHarnessTaskPool/lm_eval.yaml +12 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.12.dist-info → fusion_bench-0.2.14.dist-info}/top_level.txt +0 -0
fusion_bench/taskpool/openclip_vision/openclip_taskpool.py
ADDED

```diff
@@ -0,0 +1,196 @@
+import itertools
+import json
+import logging
+import os
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
+
+import lightning.fabric
+import open_clip
+import torch
+from omegaconf import DictConfig
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+from torchmetrics import Accuracy, MeanMetric
+from torchmetrics.classification.accuracy import MulticlassAccuracy
+from tqdm.auto import tqdm
+
+from fusion_bench import BaseTaskPool
+from fusion_bench.dataset.clip_dataset import CLIPDataset
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.modelpool.openclip_vision.modelpool import load_classifier_head
+from fusion_bench.models.open_clip import (
+    ClassificationHead,
+    ImageClassifier,
+    ImageEncoder,
+)
+from fusion_bench.models.open_clip.variables_and_paths import OPENCLIP_CACHEDIR
+from fusion_bench.utils import count_parameters, instantiate
+
+if TYPE_CHECKING:
+    from fusion_bench.modelpool import OpenCLIPVisionModelPool
+    from fusion_bench.programs import FabricModelFusionProgram
+
+log = logging.getLogger(__name__)
+
+
+class OpenCLIPVisionModelTaskPool(
+    BaseTaskPool,
+    LightningFabricMixin,
+):
+    _is_setup = False
+
+    _program: "FabricModelFusionProgram"
+
+    processor: Optional[Callable] = None
+    test_datasets: Dict[str, CLIPDataset]
+
+    def __init__(
+        self,
+        test_datasets: Union[DictConfig, Dict[str, Dataset]],
+        classification_heads: Union[DictConfig, Dict[str, ClassificationHead]],
+        dataloader_kwargs: DictConfig,
+        model_name: Optional[str] = None,
+        fast_dev_run: bool = False,
+        **kwargs,
+    ):
+        self._test_datasets = test_datasets
+        self._classifier_heads = classification_heads
+        self._dataloader_kwargs = dataloader_kwargs
+        self._model_name = model_name
+        self.fast_dev_run = fast_dev_run
+        super().__init__(**kwargs)
+
+    def setup(self):
+        # setup the processor
+        if self._program is not None and self._program.modelpool is not None:
+            modelpool: "OpenCLIPVisionModelPool" = self._program.modelpool
+            self.processor = modelpool.test_processor
+        elif self._model_name is not None:
+            _, _, self.processor = open_clip.create_model_and_transforms(
+                self._model_name,
+                pretrained="openai",
+                cache_dir=OPENCLIP_CACHEDIR,
+            )
+        else:
+            raise ValueError("Modelpool or model_name is not set")
+
+        # setup the test datasets
+        self.test_datasets = {
+            name: instantiate(dataset) if isinstance(dataset, DictConfig) else dataset
+            for name, dataset in self._test_datasets.items()
+        }
+        self.test_datasets = {
+            name: CLIPDataset(dataset, self.processor)
+            for name, dataset in self.test_datasets.items()
+        }
+        self.test_dataloaders = {
+            name: self.fabric.setup_dataloaders(
+                DataLoader(dataset, **self._dataloader_kwargs)
+            )
+            for name, dataset in self.test_datasets.items()
+        }
+
+        # setup classifier heads
+        self.classifier_heads = {
+            name: load_classifier_head(head).to(self.fabric.device)
+            for name, head in self._classifier_heads.items()
+        }
+        self._is_setup = True
+
+    @torch.no_grad()
+    def _evaluate(
+        self,
+        classifier: ImageClassifier,
+        test_loader: DataLoader,
+        num_classes: int,
+        task_name: str,
+    ):
+        accuracy: MulticlassAccuracy = Accuracy(
+            task="multiclass", num_classes=num_classes
+        )
+        classifier.eval()
+        loss_metric = MeanMetric()
+        # if fast_dev_run is set, we only evaluate on a batch of the data
+        if self.fast_dev_run:
+            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
+            test_loader = itertools.islice(test_loader, 1)
+        else:
+            test_loader = test_loader
+
+        pbar = tqdm(
+            test_loader,
+            desc=f"Evaluating {task_name}",
+            leave=False,
+            dynamic_ncols=True,
+        )
+        for batch in pbar:
+            inputs, targets = batch
+            logits = classifier(inputs)
+            loss = F.cross_entropy(logits, targets)
+            loss_metric.update(loss.detach().cpu())
+            acc = accuracy(logits.detach().cpu(), targets.detach().cpu())
+            pbar.set_postfix(
+                {
+                    "accuracy": accuracy.compute().item(),
+                    "loss": loss_metric.compute().item(),
+                }
+            )
+
+        acc = accuracy.compute().item()
+        loss = loss_metric.compute().item()
+        results = {"accuracy": acc, "loss": loss}
+        return results
+
+    def evaluate(self, model: ImageEncoder, **kwargs):
+        if not self._is_setup:
+            self.setup()
+
+        report = {}
+        # collect basic model information
+        training_params, all_params = count_parameters(model)
+        report["model_info"] = {
+            "trainable_params": training_params,
+            "all_params": all_params,
+            "trainable_percentage": training_params / all_params,
+        }
+
+        if not lightning.fabric.is_wrapped(model):
+            model = self.fabric.setup_module(model)
+
+        pbar = tqdm(
+            self.test_dataloaders.items(),
+            desc="Evaluating tasks",
+            total=len(self.test_dataloaders),
+        )
+        for task_name, test_dataloader in pbar:
+            classifier = ImageClassifier(model, self.classifier_heads[task_name])
+            num_classes = self.classifier_heads[task_name].weight.size(0)
+            result = self._evaluate(
+                classifier,
+                test_dataloader,
+                num_classes=num_classes,
+                task_name=task_name,
+            )
+            report[task_name] = result
+
+        # calculate the average accuracy and loss
+        if "average" not in report:
+            report["average"] = {}
+        accuracies = [
+            value["accuracy"]
+            for key, value in report.items()
+            if "accuracy" in value
+        ]
+        if len(accuracies) > 0:
+            average_accuracy = sum(accuracies) / len(accuracies)
+            report["average"]["accuracy"] = average_accuracy
+        losses = [value["loss"] for key, value in report.items() if "loss" in value]
+        if len(losses) > 0:
+            average_loss = sum(losses) / len(losses)
+            report["average"]["loss"] = average_loss
+
+        log.info(f"Evaluation Result: {report}")
+        if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
+            with open(os.path.join(self.log_dir, "report.json"), "w") as fp:
+                json.dump(report, fp)
+        return report
```
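The `_evaluate` loop above streams batch statistics through torchmetrics, so accuracy and loss accumulate across batches and are finalized with `compute()`. A self-contained sketch of that pattern, with random tensors standing in for a real classifier and dataloader:

```python
import torch
from torch.nn import functional as F
from torchmetrics import Accuracy, MeanMetric

accuracy = Accuracy(task="multiclass", num_classes=10)
loss_metric = MeanMetric()
for _ in range(3):  # stand-in for iterating a test dataloader
    logits = torch.randn(8, 10)           # stand-in for classifier(inputs)
    targets = torch.randint(0, 10, (8,))
    loss_metric.update(F.cross_entropy(logits, targets))
    accuracy(logits, targets)             # updates state, returns batch accuracy
print({"accuracy": accuracy.compute().item(), "loss": loss_metric.compute().item()})
```

The dataset-level values returned by `compute()` after the loop are what populate each task's `results` entry in the report.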
fusion_bench/utils/data.py
CHANGED

```diff
@@ -9,6 +9,18 @@ from torch.utils.data import DataLoader, Dataset
 
 
 class InfiniteDataLoader:
+    """
+    A wrapper class for DataLoader to create an infinite data loader.
+    This is useful in case we are only interested in the number of steps and not the number of epochs.
+
+    This class wraps a DataLoader and provides an iterator that resets
+    when the end of the dataset is reached, creating an infinite loop.
+
+    Attributes:
+        data_loader (DataLoader): The DataLoader to wrap.
+        data_iter (iterator): An iterator over the DataLoader.
+    """
+
     def __init__(self, data_loader: DataLoader):
         self.data_loader = data_loader
         self.data_iter = iter(data_loader)
```
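A brief usage sketch of the step-driven pattern this docstring describes. It assumes `InfiniteDataLoader` implements the iterator protocol (rewinding the underlying iterator when exhausted), which the docstring states but this hunk does not show:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from fusion_bench.utils.data import InfiniteDataLoader

dataset = TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))
loader = InfiniteDataLoader(DataLoader(dataset, batch_size=4))
for step in range(100):  # budget in optimization steps; epochs are implicit
    x, y = next(loader)  # assumed to wrap around once the dataset is exhausted
```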
fusion_bench/utils/devices.py
CHANGED

```diff
@@ -229,3 +229,17 @@ def cleanup_cuda():
     gc.collect()
     torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()
+
+
+def print_memory_usage(print_fn=print):
+    """
+    Print the current GPU memory usage.
+
+    Returns:
+        str: A string containing the allocated and cached memory in MB.
+    """
+    allocated = torch.cuda.memory_allocated() / 1024**2  # convert to MB
+    cached = torch.cuda.memory_reserved() / 1024**2  # convert to MB
+    print_str = f"Allocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
+    print_fn(print_str)
+    return print_str
```
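A short usage sketch: the helper both prints through the injectable `print_fn` and returns the formatted string, and it requires a CUDA-enabled PyTorch runtime since it queries `torch.cuda` directly:

```python
import logging

from fusion_bench.utils.devices import print_memory_usage

log = logging.getLogger(__name__)

print_memory_usage()                    # default: writes to stdout via print
summary = print_memory_usage(log.info)  # or route the same string to a logger
```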
fusion_bench/utils/instantiate.py
CHANGED

```diff
@@ -2,6 +2,7 @@
 # Modified from Hydra
 import copy
 import functools
+from contextlib import contextmanager
 from enum import Enum
 from textwrap import dedent
 from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
@@ -30,6 +31,17 @@ Function to be used for printing function calls.
 CATCH_EXCEPTION = True
 
 
+@contextmanager
+def set_print_function_call(value: bool):
+    global PRINT_FUNCTION_CALL
+    old_value = PRINT_FUNCTION_CALL
+    PRINT_FUNCTION_CALL = value
+    try:
+        yield
+    finally:
+        PRINT_FUNCTION_CALL = old_value
+
+
 def is_instantiable(config: Union[DictConfig, Any]) -> bool:
     if OmegaConf.is_dict(config):
         return "_target_" in config
```
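Because the new `set_print_function_call` is a `@contextmanager` that restores the previous flag in its `finally` block, it can temporarily toggle call tracing around instantiation. A hedged sketch, assuming the module's `instantiate` function accepts a `_target_`-style config as `is_instantiable` suggests; `collections.Counter` is an arbitrary stdlib target used only for illustration:

```python
from omegaconf import DictConfig

from fusion_bench.utils.instantiate import instantiate, set_print_function_call

cfg = DictConfig({"_target_": "collections.Counter"})
with set_print_function_call(False):  # suppress function-call tracing in this block
    counter = instantiate(cfg)
# the previous PRINT_FUNCTION_CALL value is restored here, even if an error was raised
```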
fusion_bench/utils/misc.py
CHANGED

```diff
@@ -1,6 +1,6 @@
-from typing import Iterable
+from typing import Iterable, List
 
-__all__ = ["first", "has_length"]
+__all__ = ["first", "has_length", "join_list"]
 
 
 def first(iterable: Iterable):
@@ -16,3 +16,10 @@ def has_length(dataset):
     except TypeError:
         # TypeError: len() of unsized object
         return False
+
+
+def join_list(list_of_list: List[List]):
+    ans = []
+    for item in list_of_list:
+        ans.extend(item)
+    return ans
```
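The new `join_list` flattens exactly one level of nesting, behaving like `itertools.chain.from_iterable` collected into a list; a quick equivalence check:

```python
from itertools import chain

from fusion_bench.utils.misc import join_list

nested = [[1, 2], [3], []]
assert join_list(nested) == [1, 2, 3]
assert list(chain.from_iterable(nested)) == join_list(nested)
```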
fusion_bench/utils/packages.py
CHANGED

```diff
@@ -82,3 +82,17 @@ def import_object(abs_obj_name: str):
     module_name, obj_name = abs_obj_name.rsplit(".", 1)
     module = importlib.import_module(module_name)
     return getattr(module, obj_name)
+
+
+def compare_versions(v1, v2):
+    """Compare two version strings.
+    Returns -1 if v1 < v2, 0 if v1 == v2, 1 if v1 > v2"""
+
+    v1 = version.parse(v1)
+    v2 = version.parse(v2)
+    if v1 < v2:
+        return -1
+    elif v1 > v2:
+        return 1
+    else:
+        return 0
```
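Since `compare_versions` goes through `version.parse` (presumably `packaging.version`, imported elsewhere in the module outside this hunk), the ordering is PEP 440-aware rather than lexicographic:

```python
from fusion_bench.utils.packages import compare_versions

assert compare_versions("0.2.12", "0.2.14") == -1
assert compare_versions("0.2.2", "0.2.14") == -1  # numeric ordering: "0.2.2" < "0.2.14"
assert compare_versions("0.2.14", "0.2.14") == 0
```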
fusion_bench/utils/parameters.py
CHANGED

```diff
@@ -252,7 +252,7 @@ def print_parameters(
 
 
 def check_parameters_all_equal(
-    list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]]
+    list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]],
 ) -> None:
     """
     Checks if all models have the same parameters.
```
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: fusion_bench
|
|
3
|
-
Version: 0.2.
|
|
3
|
+
Version: 0.2.14
|
|
4
4
|
Summary: A Comprehensive Benchmark of Deep Model Fusion
|
|
5
5
|
Author-email: Anke Tang <tang.anke@foxmail.com>
|
|
6
6
|
License: MIT License
|
|
@@ -45,6 +45,8 @@ Requires-Dist: rich
|
|
|
45
45
|
Requires-Dist: scipy
|
|
46
46
|
Requires-Dist: h5py
|
|
47
47
|
Requires-Dist: pytest
|
|
48
|
+
Provides-Extra: lm-eval-harness
|
|
49
|
+
Requires-Dist: lm-eval; extra == "lm-eval-harness"
|
|
48
50
|
Dynamic: license-file
|
|
49
51
|
|
|
50
52
|
<div align='center'>
|
|
@@ -122,7 +124,7 @@ Merging multiple expert models offers a promising approach for performing multi-
|
|
|
122
124
|
|
|
123
125
|
## Installation
|
|
124
126
|
|
|
125
|
-
|
|
127
|
+
Install from PyPI:
|
|
126
128
|
|
|
127
129
|
```bash
|
|
128
130
|
pip install fusion-bench
|
|
@@ -137,6 +139,24 @@ cd fusion_bench
|
|
|
137
139
|
pip install -e . # install the package in editable mode
|
|
138
140
|
```
|
|
139
141
|
|
|
142
|
+
### Install with [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness)
|
|
143
|
+
|
|
144
|
+
[](https://doi.org/10.5281/zenodo.10256836)
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
```bash
|
|
148
|
+
pip install "fusion-bench[lm-eval-harness]"
|
|
149
|
+
```
|
|
150
|
+
|
|
151
|
+
or install from local directory
|
|
152
|
+
|
|
153
|
+
```bash
|
|
154
|
+
pip install -e ".[lm-eval-harness]"
|
|
155
|
+
```
|
|
156
|
+
|
|
157
|
+
This will install the latest version of fusion-bench and the dependencies required for LM-Eval Harness.
|
|
158
|
+
Documentation for using LM-Eval Harness within FusionBench framework can be found at [this online documentation](https://tanganke.github.io/fusion_bench/taskpool/lm_eval_harness) or in the [`docs/taskpool/lm_eval_harness.md`](docs/taskpool/lm_eval_harness.md) markdown file.
|
|
159
|
+
|
|
140
160
|
## Introduction to Deep Model Fusion
|
|
141
161
|
|
|
142
162
|
Deep model fusion is a technique that merges, ensemble, or fuse multiple deep neural networks to obtain a unified model.
|