fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +3 -1
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/doge_ta/__init__.py +2 -0
- fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +102 -84
- fusion_bench/method/opcm/task_arithmetic.py +35 -21
- fusion_bench/method/opcm/ties_merging.py +71 -52
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +5 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- fusion_bench/method/DOGE_TA/__init__.py +0 -2
- /fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
fusion_bench/taskpool/__init__.py
CHANGED
@@ -10,12 +10,13 @@ _import_structure = {
     "clip_vision": [
         "CLIPVisionModelTaskPool",
         "SparseWEMoECLIPVisionModelTaskPool",
-        "
+        "RankoneMoECLIPVisionModelTaskPool",
     ],
     "dummy": ["DummyTaskPool"],
     "gpt2_text_classification": ["GPT2TextClassificationTaskPool"],
-    "nyuv2_taskpool": ["NYUv2TaskPool"],
     "llama": ["LlamaTestGenerationTaskPool"],
+    "nyuv2_taskpool": ["NYUv2TaskPool"],
+    "openclip_vision": ["OpenCLIPVisionModelTaskPool"],
 }


@@ -23,13 +24,14 @@ if TYPE_CHECKING:
     from .base_pool import BaseTaskPool
     from .clip_vision import (
         CLIPVisionModelTaskPool,
-
+        RankoneMoECLIPVisionModelTaskPool,
         SparseWEMoECLIPVisionModelTaskPool,
     )
     from .dummy import DummyTaskPool
     from .gpt2_text_classification import GPT2TextClassificationTaskPool
     from .llama import LlamaTestGenerationTaskPool
     from .nyuv2_taskpool import NYUv2TaskPool
+    from .openclip_vision import OpenCLIPVisionModelTaskPool

 else:
     sys.modules[__name__] = LazyImporter(
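Note on the lazy-import plumbing above: adding a key to `_import_structure` together with the matching `TYPE_CHECKING` import is what exposes a taskpool at the package level. A minimal sketch of the resulting import surface, assuming fusion-bench 0.2.13 is installed:

```python
# Minimal sketch (assumes fusion-bench 0.2.13 is installed): the names added in
# the hunk above resolve from the top-level taskpool package; LazyImporter defers
# the actual module import until the attribute is first accessed.
from fusion_bench.taskpool import (
    OpenCLIPVisionModelTaskPool,
    RankoneMoECLIPVisionModelTaskPool,
)
```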
fusion_bench/taskpool/clip_vision/__init__.py
CHANGED
@@ -1,4 +1,5 @@
 # flake8: noqa F401
 from .clip_rankone_moe_taskpool import RankoneMoECLIPVisionModelTaskPool
+from .clip_smile_taskpool import SmileCLIPVisionModelTaskPool
 from .clip_sparse_wemoe_taskpool import SparseWEMoECLIPVisionModelTaskPool
 from .taskpool import CLIPVisionModelTaskPool
fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py
CHANGED
@@ -12,36 +12,7 @@ from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.models.rankone_moe import RankOneMoE

 from .taskpool import CLIPVisionModelTaskPool
-
-
-class LayerWiseRoutingWeightSaver:
-    def __init__(self, save_path: Path, max_num: Optional[int] = None):
-        self.save_path = save_path
-        self.max_num = max_num
-        self.routing_weights = []
-
-    def __call__(self, module, input: Tuple[Tensor], output: Tensor):
-        assert isinstance(output, Tensor), "Output is expected to be a Tensor"
-        # (batch_size, num_tokens, num_experts)
-        routing_weights = output.detach().cpu()
-        if self.max_num is not None and self.max_num > 0:
-            if len(self.routing_weights) > self.max_num:
-                return
-            elif routing_weights.size(0) + len(self.routing_weights) > self.max_num:
-                self.routing_weights.append(
-                    routing_weights[: self.max_num - len(self.routing_weights)]
-                )
-            else:
-                self.routing_weights.append(routing_weights)
-        else:
-            self.routing_weights.append(routing_weights)
-
-    def save_routing_weights(self):
-        routing_weights = torch.cat(self.routing_weights, dim=0)
-        if self.save_path is not None:
-            self.save_path.parent.mkdir(parents=True, exist_ok=True)
-            print(f"Saving routing weights to {self.save_path}")
-            torch.save(routing_weights, self.save_path)
+from .utils.routing_analysis_utils import LayerWiseRoutingWeightSaver


 class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
@@ -109,4 +80,5 @@ class RankoneMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
         # remove hooks for saving layer-wise routing weights
         for i, handle in self._layer_wise_routing_weights_save_hook_handles.items():
             self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
+            self._layer_wise_routing_weights_save_hook_handles.pop(i)
             handle.remove()
fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py
ADDED
@@ -0,0 +1,102 @@
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.utils.hooks import RemovableHandle
+from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel
+from transformers.models.clip.modeling_clip import CLIPVisionTransformer
+
+from fusion_bench.method.smile_upscaling import SmileMoELinear
+from fusion_bench.models.hf_clip import HFCLIPClassifier
+
+from .taskpool import CLIPVisionModelTaskPool
+from .utils.routing_analysis_utils import LayerWiseRoutingWeightSaver
+
+
+class SmileCLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
+
+    # hooks and handles for saving layer-wise routing weights
+    _layer_wise_routing_weights_save_hooks: Dict[Any, LayerWiseRoutingWeightSaver] = {}
+    _layer_wise_routing_weights_save_hook_handles: Dict[Any, RemovableHandle] = {}
+
+    def __init__(
+        self,
+        linear_module_names: Union[List[str], str],
+        layer_wise_routing_weights_save_path: Optional[str],
+        layer_wise_routing_weights_max_num: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Initialize the SMILECLIPVisionModelTaskPool.
+
+        Args:
+            linear_module_names (Union[List[str], str]): The names of the linear modules to save the layer-wise routing weights for.
+            layer_wise_routing_weights_save_path (Optional[str]): The path to save the layer-wise routing weights.
+            layer_wise_routing_weights_max_num (Optional[int]): The maximum number of layer-wise routing weights to save.
+        """
+        # linear module names
+        assert linear_module_names is not None, "linear_module_names must be provided"
+        self.linear_module_names = (
+            [linear_module_names]
+            if isinstance(linear_module_names, str)
+            else list(linear_module_names)
+        )
+        # save path for layer-wise routing weights
+        self._layer_wise_routing_weights_save_path = (
+            layer_wise_routing_weights_save_path
+        )
+        self.layer_wise_routing_weights_save_path = (
+            Path(layer_wise_routing_weights_save_path)
+            if layer_wise_routing_weights_save_path is not None
+            else None
+        )
+        self.layer_wise_routing_weights_max_num = layer_wise_routing_weights_max_num
+        super().__init__(**kwargs)
+
+    def on_task_evaluation_begin(self, classifier: HFCLIPClassifier, task_name: str):
+        super().on_task_evaluation_begin(classifier, task_name)
+        if self.layer_wise_routing_weights_save_path is not None:
+            # setup hooks for saving layer-wise routing weights
+            assert isinstance(
+                classifier.clip_model.vision_model,
+                (CLIPVisionTransformer, CLIPVisionModel),
+            ), "Vision model is expected to be a CLIPVisionTransformer"
+            vision_model = classifier.clip_model.vision_model
+            if isinstance(vision_model, CLIPVisionModel):
+                vision_model = vision_model.vision_model
+            # assign forward hooks for each layer
+
+            for i, layer in enumerate(vision_model.encoder.layers):
+                for linear_module_name in self.linear_module_names:
+                    linear_module = layer.get_submodule(linear_module_name)
+                    assert isinstance(
+                        linear_module,
+                        (SmileMoELinear),
+                    ), f"Linear module is expected to be a SmileMoELinear, but got {type(linear_module)}"
+                    # layer-wise routing weights
+                    hook = LayerWiseRoutingWeightSaver(
+                        self.layer_wise_routing_weights_save_path
+                        / task_name
+                        / f"layer_{i}_{linear_module_name}.pt",
+                        max_num=self.layer_wise_routing_weights_max_num,
+                    )
+                    self._layer_wise_routing_weights_save_hooks[
+                        (i, linear_module_name)
+                    ] = hook
+                    self._layer_wise_routing_weights_save_hook_handles[
+                        (i, linear_module_name)
+                    ] = linear_module.gate.register_forward_hook(hook)
+
+    def on_task_evaluation_end(self):
+        super().on_task_evaluation_end()
+        if self.layer_wise_routing_weights_save_path is not None:
+            # remove hooks for saving layer-wise routing weights
+            for (
+                key,
+                handle,
+            ) in self._layer_wise_routing_weights_save_hook_handles.items():
+                self._layer_wise_routing_weights_save_hooks[key].save_routing_weights()
+                self._layer_wise_routing_weights_save_hook_handles.pop(key)
+                handle.remove()
fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py
CHANGED
@@ -15,36 +15,7 @@ from fusion_bench.models.sparse_we_moe import (
 )

 from .taskpool import CLIPVisionModelTaskPool
-
-
-class LayerWiseRoutingWeightSaver:
-    def __init__(self, save_path: Path, max_num: Optional[int] = None):
-        self.save_path = save_path
-        self.max_num = max_num
-        self.routing_weights = []
-
-    def __call__(self, module, input: Tuple[Tensor], output: Tensor):
-        assert isinstance(output, Tensor), "Output is expected to be a Tensor"
-        # (batch_size, num_tokens, num_experts)
-        routing_weights = output.detach().cpu()
-        if self.max_num is not None and self.max_num > 0:
-            if len(self.routing_weights) > self.max_num:
-                return
-            elif routing_weights.size(0) + len(self.routing_weights) > self.max_num:
-                self.routing_weights.append(
-                    routing_weights[: self.max_num - len(self.routing_weights)]
-                )
-            else:
-                self.routing_weights.append(routing_weights)
-        else:
-            self.routing_weights.append(routing_weights)
-
-    def save_routing_weights(self):
-        routing_weights = torch.cat(self.routing_weights, dim=0)
-        if self.save_path is not None:
-            self.save_path.parent.mkdir(parents=True, exist_ok=True)
-            print(f"Saving routing weights to {self.save_path}")
-            torch.save(routing_weights, self.save_path)
+from .utils.routing_analysis_utils import LayerWiseRoutingWeightSaver


 class SparseWEMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
@@ -117,4 +88,5 @@ class SparseWEMoECLIPVisionModelTaskPool(CLIPVisionModelTaskPool):
         # remove hooks for saving layer-wise routing weights
         for i, handle in self._layer_wise_routing_weights_save_hook_handles.items():
             self._layer_wise_routing_weights_save_hooks[i].save_routing_weights()
+            self._layer_wise_routing_weights_save_hook_handles.pop(i)
             handle.remove()
fusion_bench/taskpool/clip_vision/taskpool.py
CHANGED
@@ -32,8 +32,7 @@ from fusion_bench.mixins import LightningFabricMixin
 from fusion_bench.models.hf_clip import HFCLIPClassifier
 from fusion_bench.taskpool import BaseTaskPool
 from fusion_bench.tasks.clip_classification import get_classnames_and_templates
-from fusion_bench.utils import instantiate
-from fusion_bench.utils.parameters import count_parameters
+from fusion_bench.utils import count_parameters, instantiate

 if TYPE_CHECKING:
     from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper

fusion_bench/taskpool/clip_vision/utils/__init__.py
File without changes
fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py
ADDED
@@ -0,0 +1,65 @@
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import torch
+from torch import Tensor
+
+
+def _number_of_samples(routing_weights: List[Tensor]):
+    count = 0
+    for routing_weight in routing_weights:
+        count += routing_weight.size(0)
+    return count
+
+
+class LayerWiseRoutingWeightSaver:
+    """
+    A hook for saving layer-wise routing weights.
+    """
+
+    save_path: Path
+    "The path to save the layer-wise routing weights."
+    max_num: Optional[int]
+    "The maximum number of layer-wise routing weights to save. If None, all routing weights will be saved."
+    routing_weights: List[Tensor]
+    "The list of layer-wise routing weights."
+
+    def __init__(self, save_path: Path, max_num: Optional[int] = None):
+        """
+        Args:
+            save_path (Path): The path to save the layer-wise routing weights.
+            max_num (Optional[int]): The maximum number of layer-wise routing weights to save. If None, all routing weights will be saved.
+        """
+        self.save_path = save_path
+        self.max_num = max_num
+        self.routing_weights = []
+
+    def __call__(self, module, input: Tuple[Tensor], output: Tensor):
+        assert isinstance(output, Tensor), "Output is expected to be a Tensor"
+        # (batch_size, num_tokens, num_experts)
+        routing_weights = output.detach().cpu()
+        if self.max_num is not None and self.max_num > 0:
+            if _number_of_samples(self.routing_weights) > self.max_num:
+                return
+            elif (
+                routing_weights.size(0) + _number_of_samples(self.routing_weights)
+                > self.max_num
+            ):
+                self.routing_weights.append(
+                    routing_weights[
+                        : self.max_num - _number_of_samples(self.routing_weights)
+                    ]
+                )
+            else:
+                self.routing_weights.append(routing_weights)
+        else:
+            self.routing_weights.append(routing_weights)
+
+    def save_routing_weights(self):
+        routing_weights = torch.cat(self.routing_weights, dim=0)
+        if self.save_path is not None:
+            self.save_path.parent.mkdir(parents=True, exist_ok=True)
+            print(
+                f"Saving routing weights to {self.save_path}. Size: {routing_weights.size()}"
+            )
+            torch.save(routing_weights, self.save_path)
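For context, the saver above is used as a PyTorch forward hook: the taskpools register it on a routing/gate module and flush it after evaluation. A minimal sketch of that pattern outside the taskpools (the `nn.Linear` gate and the output path are stand-ins, not the package's real modules or paths):

```python
# Minimal sketch mirroring how the taskpools use LayerWiseRoutingWeightSaver.
# The toy gate module and the save path are illustrative stand-ins.
from pathlib import Path

import torch
from torch import nn

from fusion_bench.taskpool.clip_vision.utils.routing_analysis_utils import (
    LayerWiseRoutingWeightSaver,
)

gate = nn.Linear(16, 4)  # stand-in for a router/gate that outputs routing weights
saver = LayerWiseRoutingWeightSaver(Path("outputs/routing/layer_0_gate.pt"), max_num=1000)
handle = gate.register_forward_hook(saver)

with torch.no_grad():
    gate(torch.randn(8, 16))  # each forward call appends detached routing weights

saver.save_routing_weights()  # concatenates the collected tensors and writes them to disk
handle.remove()
```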
fusion_bench/taskpool/gpt2_text_classification.py
CHANGED
@@ -139,11 +139,40 @@ class GPT2TextClassificationTaskPool(BaseTaskPool, LightningFabricMixin):
         return dataloader

     @override
-    def evaluate(self, model: GPT2Model):
+    def evaluate(self, model: GPT2Model, name: str = None):
+        """Evaluate the model on the test datasets.
+
+        Args:
+            model (GPT2Model): The model to evaluate.
+            name (str, optional): The name of the model. Defaults to None. This is used to identify the model in the report.
+
+        Returns:
+            dict: A dictionary containing the evaluation results for each task.
+        """
         report = {}
+        if name is not None:
+            report["name"] = name
         for task_name in (pbar := tqdm(self._test_datasets, desc="Evaluating tasks")):
             pbar.set_description(f"Evaluating task {task_name}")
             dataloader = self.get_test_dataloader(task_name)
             result = self.evaluate_single_task(task_name, model, dataloader)
             report[task_name] = result
+
+        # calculate the average accuracy and loss
+        if "average" not in report:
+            report["average"] = {}
+        accuracies = [
+            value["accuracy"]
+            for key, value in report.items()
+            if isinstance(value, dict) and "accuracy" in value
+        ]
+        if len(accuracies) > 0:
+            average_accuracy = sum(accuracies) / len(accuracies)
+            report["average"]["accuracy"] = average_accuracy
+        losses = [value["loss"] for key, value in report.items() if "loss" in value]
+        if len(losses) > 0:
+            average_loss = sum(losses) / len(losses)
+            report["average"]["loss"] = average_loss
+
+        log.info(f"Evaluation Result: {report}")
         return report
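With the added averaging logic, the report now carries an `average` entry alongside the per-task metrics. A hypothetical illustration of the resulting dictionary shape (task names and values are made up):

```python
# Hypothetical shape of the report returned by the new evaluate() (values are made up).
report = {
    "name": "merged-gpt2",                      # only present when `name` is passed
    "cola": {"accuracy": 0.69, "loss": 0.61},   # one entry per evaluated task
    "sst2": {"accuracy": 0.91, "loss": 0.27},
    "average": {"accuracy": 0.80, "loss": 0.44},
}
```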
fusion_bench/taskpool/openclip_vision/__init__.py
ADDED
@@ -0,0 +1 @@
+from .openclip_taskpool import OpenCLIPVisionModelTaskPool
fusion_bench/taskpool/openclip_vision/openclip_taskpool.py
ADDED
@@ -0,0 +1,196 @@
+import itertools
+import json
+import logging
+import os
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
+
+import lightning.fabric
+import open_clip
+import torch
+from omegaconf import DictConfig
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset
+from torchmetrics import Accuracy, MeanMetric
+from torchmetrics.classification.accuracy import MulticlassAccuracy
+from tqdm.auto import tqdm
+
+from fusion_bench import BaseTaskPool
+from fusion_bench.dataset.clip_dataset import CLIPDataset
+from fusion_bench.mixins import LightningFabricMixin
+from fusion_bench.modelpool.openclip_vision.modelpool import load_classifier_head
+from fusion_bench.models.open_clip import (
+    ClassificationHead,
+    ImageClassifier,
+    ImageEncoder,
+)
+from fusion_bench.models.open_clip.variables_and_paths import OPENCLIP_CACHEDIR
+from fusion_bench.utils import count_parameters, instantiate
+
+if TYPE_CHECKING:
+    from fusion_bench.modelpool import OpenCLIPVisionModelPool
+    from fusion_bench.programs import FabricModelFusionProgram
+
+log = logging.getLogger(__name__)
+
+
+class OpenCLIPVisionModelTaskPool(
+    BaseTaskPool,
+    LightningFabricMixin,
+):
+    _is_setup = False
+
+    _program: "FabricModelFusionProgram"
+
+    processor: Optional[Callable] = None
+    test_datasets: Dict[str, CLIPDataset]
+
+    def __init__(
+        self,
+        test_datasets: Union[DictConfig, Dict[str, Dataset]],
+        classification_heads: Union[DictConfig, Dict[str, ClassificationHead]],
+        dataloader_kwargs: DictConfig,
+        model_name: Optional[str] = None,
+        fast_dev_run: bool = False,
+        **kwargs,
+    ):
+        self._test_datasets = test_datasets
+        self._classifier_heads = classification_heads
+        self._dataloader_kwargs = dataloader_kwargs
+        self._model_name = model_name
+        self.fast_dev_run = fast_dev_run
+        super().__init__(**kwargs)
+
+    def setup(self):
+        # setup the processor
+        if self._program is not None and self._program.modelpool is not None:
+            modelpool: "OpenCLIPVisionModelPool" = self._program.modelpool
+            self.processor = modelpool.test_processor
+        elif self._model_name is not None:
+            _, _, self.processor = open_clip.create_model_and_transforms(
+                self._model_name,
+                pretrained="openai",
+                cache_dir=OPENCLIP_CACHEDIR,
+            )
+        else:
+            raise ValueError("Modelpool or model_name is not set")
+
+        # setup the test datasets
+        self.test_datasets = {
+            name: instantiate(dataset) if isinstance(dataset, DictConfig) else dataset
+            for name, dataset in self._test_datasets.items()
+        }
+        self.test_datasets = {
+            name: CLIPDataset(dataset, self.processor)
+            for name, dataset in self.test_datasets.items()
+        }
+        self.test_dataloaders = {
+            name: self.fabric.setup_dataloaders(
+                DataLoader(dataset, **self._dataloader_kwargs)
+            )
+            for name, dataset in self.test_datasets.items()
+        }
+
+        # setup classifier heads
+        self.classifier_heads = {
+            name: load_classifier_head(head).to(self.fabric.device)
+            for name, head in self._classifier_heads.items()
+        }
+        self._is_setup = True
+
+    @torch.no_grad()
+    def _evaluate(
+        self,
+        classifier: ImageClassifier,
+        test_loader: DataLoader,
+        num_classes: int,
+        task_name: str,
+    ):
+        accuracy: MulticlassAccuracy = Accuracy(
+            task="multiclass", num_classes=num_classes
+        )
+        classifier.eval()
+        loss_metric = MeanMetric()
+        # if fast_dev_run is set, we only evaluate on a batch of the data
+        if self.fast_dev_run:
+            log.info("Running under fast_dev_run mode, evaluating on a single batch.")
+            test_loader = itertools.islice(test_loader, 1)
+        else:
+            test_loader = test_loader
+
+        pbar = tqdm(
+            test_loader,
+            desc=f"Evaluating {task_name}",
+            leave=False,
+            dynamic_ncols=True,
+        )
+        for batch in pbar:
+            inputs, targets = batch
+            logits = classifier(inputs)
+            loss = F.cross_entropy(logits, targets)
+            loss_metric.update(loss.detach().cpu())
+            acc = accuracy(logits.detach().cpu(), targets.detach().cpu())
+            pbar.set_postfix(
+                {
+                    "accuracy": accuracy.compute().item(),
+                    "loss": loss_metric.compute().item(),
+                }
+            )
+
+        acc = accuracy.compute().item()
+        loss = loss_metric.compute().item()
+        results = {"accuracy": acc, "loss": loss}
+        return results
+
+    def evaluate(self, model: ImageEncoder, **kwargs):
+        if not self._is_setup:
+            self.setup()
+
+        report = {}
+        # collect basic model information
+        training_params, all_params = count_parameters(model)
+        report["model_info"] = {
+            "trainable_params": training_params,
+            "all_params": all_params,
+            "trainable_percentage": training_params / all_params,
+        }
+
+        if not lightning.fabric.is_wrapped(model):
+            model = self.fabric.setup_module(model)
+
+        pbar = tqdm(
+            self.test_dataloaders.items(),
+            desc="Evaluating tasks",
+            total=len(self.test_dataloaders),
+        )
+        for task_name, test_dataloader in pbar:
+            classifier = ImageClassifier(model, self.classifier_heads[task_name])
+            num_classes = self.classifier_heads[task_name].weight.size(0)
+            result = self._evaluate(
+                classifier,
+                test_dataloader,
+                num_classes=num_classes,
+                task_name=task_name,
+            )
+            report[task_name] = result
+
+        # calculate the average accuracy and loss
+        if "average" not in report:
+            report["average"] = {}
+        accuracies = [
+            value["accuracy"]
+            for key, value in report.items()
+            if "accuracy" in value
+        ]
+        if len(accuracies) > 0:
+            average_accuracy = sum(accuracies) / len(accuracies)
+            report["average"]["accuracy"] = average_accuracy
+        losses = [value["loss"] for key, value in report.items() if "loss" in value]
+        if len(losses) > 0:
+            average_loss = sum(losses) / len(losses)
+            report["average"]["loss"] = average_loss
+
+        log.info(f"Evaluation Result: {report}")
+        if self.fabric.is_global_zero and len(self.fabric._loggers) > 0:
+            with open(os.path.join(self.log_dir, "report.json"), "w") as fp:
+                json.dump(report, fp)
+        return report
fusion_bench/utils/data.py
CHANGED
@@ -9,6 +9,18 @@ from torch.utils.data import DataLoader, Dataset


 class InfiniteDataLoader:
+    """
+    A wrapper class for DataLoader to create an infinite data loader.
+    This is useful in case we are only interested in the number of steps and not the number of epochs.
+
+    This class wraps a DataLoader and provides an iterator that resets
+    when the end of the dataset is reached, creating an infinite loop.
+
+    Attributes:
+        data_loader (DataLoader): The DataLoader to wrap.
+        data_iter (iterator): An iterator over the DataLoader.
+    """
+
     def __init__(self, data_loader: DataLoader):
         self.data_loader = data_loader
         self.data_iter = iter(data_loader)
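Only the docstring and `__init__` appear in this hunk; the iteration logic it describes (restarting the iterator when the underlying DataLoader is exhausted) is unchanged. A standalone sketch of that documented behavior, not the package's code verbatim:

```python
# Standalone sketch of the behavior documented above: yield batches forever by
# re-creating the DataLoader iterator whenever it raises StopIteration.
import torch
from torch.utils.data import DataLoader, TensorDataset


class InfiniteLoaderSketch:
    def __init__(self, data_loader: DataLoader):
        self.data_loader = data_loader
        self.data_iter = iter(data_loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.data_loader)  # reset and continue
            return next(self.data_iter)


loader = InfiniteLoaderSketch(DataLoader(TensorDataset(torch.arange(10.0)), batch_size=4))
first_batches = [next(loader) for _ in range(5)]  # wraps around past one epoch
```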
fusion_bench/utils/devices.py
CHANGED
@@ -229,3 +229,17 @@ def cleanup_cuda():
     gc.collect()
     torch.cuda.empty_cache()
     torch.cuda.reset_peak_memory_stats()
+
+
+def print_memory_usage(print_fn=print):
+    """
+    Print the current GPU memory usage.
+
+    Returns:
+        str: A string containing the allocated and cached memory in MB.
+    """
+    allocated = torch.cuda.memory_allocated() / 1024**2  # convert to MB
+    cached = torch.cuda.memory_reserved() / 1024**2  # convert to MB
+    print_str = f"Allocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
+    print_fn(print_str)
+    return print_str
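A short usage sketch for the new helper (assumes a CUDA-capable environment; any callable can be passed as `print_fn`):

```python
# Usage sketch for print_memory_usage (assumes CUDA is available).
import logging

from fusion_bench.utils.devices import print_memory_usage

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

print_memory_usage()                              # prints via the built-in print
summary = print_memory_usage(print_fn=log.info)   # or route the string through a logger
```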
fusion_bench/utils/instantiate.py
CHANGED
@@ -2,6 +2,7 @@
 # Modified from Hydra
 import copy
 import functools
+from contextlib import contextmanager
 from enum import Enum
 from textwrap import dedent
 from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
@@ -30,6 +31,17 @@ Function to be used for printing function calls.
 CATCH_EXCEPTION = True


+@contextmanager
+def set_print_function_call(value: bool):
+    global PRINT_FUNCTION_CALL
+    old_value = PRINT_FUNCTION_CALL
+    PRINT_FUNCTION_CALL = value
+    try:
+        yield
+    finally:
+        PRINT_FUNCTION_CALL = old_value
+
+
 def is_instantiable(config: Union[DictConfig, Any]) -> bool:
     if OmegaConf.is_dict(config):
         return "_target_" in config
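The new context manager temporarily toggles the module-level `PRINT_FUNCTION_CALL` flag and restores the previous value afterwards. A hedged usage sketch; the config passed to `instantiate` below is purely illustrative and assumes the Hydra-style `_target_` convention used elsewhere in this package:

```python
# Usage sketch: enable or silence instantiation call logging within a scope.
# The sample config is hypothetical; it only assumes the Hydra-style "_target_" key.
from fusion_bench.utils import instantiate
from fusion_bench.utils.instantiate import set_print_function_call

with set_print_function_call(False):
    obj = instantiate({"_target_": "collections.OrderedDict"})  # no call-logging inside this block
```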
fusion_bench/utils/misc.py
CHANGED
@@ -1,6 +1,6 @@
-from typing import Iterable
+from typing import Iterable, List

-__all__ = ["first", "has_length"]
+__all__ = ["first", "has_length", "join_list"]


 def first(iterable: Iterable):
@@ -16,3 +16,10 @@ def has_length(dataset):
     except TypeError:
         # TypeError: len() of unsized object
         return False
+
+
+def join_list(list_of_list: List[List]):
+    ans = []
+    for item in list_of_list:
+        ans.extend(item)
+    return ans
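`join_list` flattens one level of nesting, for example merging per-task lists into a single list:

```python
# join_list flattens exactly one level of nesting.
from fusion_bench.utils.misc import join_list

assert join_list([[1, 2], [3], []]) == [1, 2, 3]
```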
fusion_bench/utils/packages.py
CHANGED
@@ -82,3 +82,17 @@ def import_object(abs_obj_name: str):
     module_name, obj_name = abs_obj_name.rsplit(".", 1)
     module = importlib.import_module(module_name)
     return getattr(module, obj_name)
+
+
+def compare_versions(v1, v2):
+    """Compare two version strings.
+    Returns -1 if v1 < v2, 0 if v1 == v2, 1 if v1 > v2"""
+
+    v1 = version.parse(v1)
+    v2 = version.parse(v2)
+    if v1 < v2:
+        return -1
+    elif v1 > v2:
+        return 1
+    else:
+        return 0
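`compare_versions` turns version-string parsing (the `version` module imported elsewhere in this file, presumably `packaging.version`) into a three-way comparison. Applied to the two wheels in this diff:

```python
# Three-way comparison of version strings, shown on the two wheel versions in this diff.
from fusion_bench.utils.packages import compare_versions

assert compare_versions("0.2.11", "0.2.13") == -1  # 0.2.11 < 0.2.13
assert compare_versions("0.2.13", "0.2.13") == 0
assert compare_versions("0.2.13", "0.2.11") == 1
```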
fusion_bench/utils/parameters.py
CHANGED
@@ -252,7 +252,7 @@ def print_parameters(


 def check_parameters_all_equal(
-    list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]]
+    list_of_param_names: List[Union[StateDictType, nn.Module, List[str]]],
 ) -> None:
     """
     Checks if all models have the same parameters.