fusion-bench 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- fusion_bench/compat/method/__init__.py +3 -1
- fusion_bench/compat/taskpool/flan_t5_glue_text_generation.py +4 -1
- fusion_bench/constants/clip_vision.py +22 -0
- fusion_bench/dataset/clip_dataset.py +10 -2
- fusion_bench/dataset/gsm8k.py +2 -2
- fusion_bench/method/__init__.py +12 -2
- fusion_bench/method/adamerging/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/adamerging/clip_task_wise_adamerging.py +1 -29
- fusion_bench/method/doge_ta/__init__.py +2 -0
- fusion_bench/method/{DOGE_TA → doge_ta}/clip_layer_wise_adamerging.py +1 -1
- fusion_bench/method/{DOGE_TA/DOGE_TA.py → doge_ta/doge_ta.py} +1 -1
- fusion_bench/method/fisher_merging/fisher_merging.py +29 -17
- fusion_bench/method/gossip/__init__.py +3 -0
- fusion_bench/method/gossip/clip_layer_wise_gossip.py +43 -0
- fusion_bench/method/gossip/clip_task_wise_gossip.py +190 -0
- fusion_bench/method/gossip/entropy_loss.py +25 -0
- fusion_bench/method/gossip/flan_t5_layer_wise_gossip.py +388 -0
- fusion_bench/method/gossip/layer_wise_gossip.py +434 -0
- fusion_bench/method/gossip/min_norm_solvers.py +227 -0
- fusion_bench/method/gossip/task_wise_gossip.py +265 -0
- fusion_bench/method/gossip/utils.py +74 -0
- fusion_bench/method/isotropic_merging/__init__.py +1 -1
- fusion_bench/method/opcm/opcm.py +102 -84
- fusion_bench/method/opcm/task_arithmetic.py +35 -21
- fusion_bench/method/opcm/ties_merging.py +71 -52
- fusion_bench/method/pwe_moe/module.py +1 -1
- fusion_bench/method/pwe_moe/openclip_pwe_moe.py +476 -0
- fusion_bench/method/regmean/regmean.py +25 -17
- fusion_bench/method/smile_upscaling/__init__.py +1 -1
- fusion_bench/method/smile_upscaling/smile_upscaling.py +13 -10
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +7 -0
- fusion_bench/method/task_arithmetic/task_arithmetic.py +8 -6
- fusion_bench/method/ties_merging/ties_merging.py +36 -31
- fusion_bench/method/we_moe/we_moe.py +14 -15
- fusion_bench/mixins/__init__.py +6 -3
- fusion_bench/mixins/hydra_config.py +49 -0
- fusion_bench/mixins/openclip_classification.py +11 -0
- fusion_bench/mixins/simple_profiler.py +4 -2
- fusion_bench/modelpool/__init__.py +3 -1
- fusion_bench/modelpool/base_pool.py +2 -2
- fusion_bench/modelpool/openclip_vision/__init__.py +1 -0
- fusion_bench/modelpool/openclip_vision/modelpool.py +255 -0
- fusion_bench/models/open_clip/__init__.py +6 -0
- fusion_bench/models/open_clip/modeling.py +176 -0
- fusion_bench/models/open_clip/utils.py +311 -0
- fusion_bench/models/open_clip/variables_and_paths.py +56 -0
- fusion_bench/models/parameter_dict.py +54 -13
- fusion_bench/models/wrappers/layer_wise_fusion.py +1 -46
- fusion_bench/models/wrappers/layer_wise_fusion_doge_ta.py +4 -119
- fusion_bench/scripts/nyuv2_mtl_train.py +1 -1
- fusion_bench/taskpool/__init__.py +5 -3
- fusion_bench/taskpool/clip_vision/__init__.py +1 -0
- fusion_bench/taskpool/clip_vision/clip_rankone_moe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/clip_smile_taskpool.py +102 -0
- fusion_bench/taskpool/clip_vision/clip_sparse_wemoe_taskpool.py +2 -30
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -2
- fusion_bench/taskpool/clip_vision/utils/__init__.py +0 -0
- fusion_bench/taskpool/clip_vision/utils/routing_analysis_utils.py +65 -0
- fusion_bench/taskpool/gpt2_text_classification.py +30 -1
- fusion_bench/taskpool/openclip_vision/__init__.py +1 -0
- fusion_bench/taskpool/openclip_vision/openclip_taskpool.py +196 -0
- fusion_bench/utils/data.py +12 -0
- fusion_bench/utils/devices.py +14 -0
- fusion_bench/utils/instantiate.py +12 -0
- fusion_bench/utils/misc.py +9 -2
- fusion_bench/utils/packages.py +14 -0
- fusion_bench/utils/parameters.py +1 -1
- fusion_bench/utils/tensorboard.py +1 -1
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/METADATA +15 -2
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/RECORD +198 -158
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/WHEEL +1 -1
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -2
- fusion_bench_config/dataset/image_classification/test/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/emnist_letters.yaml +0 -1
- fusion_bench_config/dataset/image_classification/test/fashion_mnist.yaml +1 -1
- fusion_bench_config/dataset/image_classification/train/TALL20.yaml +0 -1
- fusion_bench_config/dataset/image_classification/train/fashion_mnist.yaml +1 -1
- fusion_bench_config/fabric/auto.yaml +0 -1
- fusion_bench_config/fabric/llama_ddp.yaml +0 -1
- fusion_bench_config/fabric/llama_fsdp.yaml +0 -1
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric/strategy/deepspeed.yaml +0 -1
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +0 -1
- fusion_bench_config/fabric_model_fusion.yaml +0 -1
- fusion_bench_config/llama_full_finetune.yaml +0 -2
- fusion_bench_config/llama_model_fusion.yaml +0 -2
- fusion_bench_config/method/ada_svd/clip_vision.yaml +0 -1
- fusion_bench_config/method/adamerging/layer_wise_flan_t5.yaml +0 -5
- fusion_bench_config/method/adamerging/layer_wise_gpt2.yaml +0 -5
- fusion_bench_config/method/adamerging/llama_sft.yaml +0 -2
- fusion_bench_config/method/adamerging.yaml +2 -2
- fusion_bench_config/method/analysis/task_vector_cos_similarity.yaml +0 -1
- fusion_bench_config/method/analysis/task_vector_violin_plot.yaml +0 -1
- fusion_bench_config/method/classification/clip_continual_finetune.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_layer_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_concrete_task_wise_adamerging.yaml +0 -1
- fusion_bench_config/method/concrete_subspace/clip_post_defense_AWM.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_post_defense_SAU.yaml +1 -12
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_layer_wise_adamerging.yaml +1 -10
- fusion_bench_config/method/concrete_subspace/clip_safe_concrete_task_arithmetic.yaml +1 -14
- fusion_bench_config/method/dare/simple_average.yaml +0 -1
- fusion_bench_config/method/dare/task_arithmetic.yaml +0 -1
- fusion_bench_config/method/dare/ties_merging.yaml +0 -2
- fusion_bench_config/method/dawe/dawe_for_clip.yaml +0 -3
- fusion_bench_config/method/{DOGE_TA/DOGE_TA.yaml → doge_ta/doge_ta.yaml} +1 -1
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +1 -1
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +0 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +0 -1
- fusion_bench_config/method/gossip/layer_wise_clip.yaml +30 -0
- fusion_bench_config/method/gossip/layer_wise_flan_t5.yaml +25 -0
- fusion_bench_config/method/isotropic_merging/iso_c.yaml +0 -1
- fusion_bench_config/method/isotropic_merging/iso_cts.yaml +0 -1
- fusion_bench_config/method/linear/linear_interpolation.yaml +0 -1
- fusion_bench_config/method/linear/llama_expo.yaml +0 -3
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +0 -5
- fusion_bench_config/method/linear/weighted_average.yaml +0 -1
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +0 -1
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +0 -4
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +0 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +0 -6
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +1 -2
- fusion_bench_config/method/model_recombination.yaml +0 -1
- fusion_bench_config/method/opcm/opcm.yaml +0 -1
- fusion_bench_config/method/opcm/task_arithmetic.yaml +0 -2
- fusion_bench_config/method/opcm/ties_merging.yaml +0 -2
- fusion_bench_config/method/opcm/weight_average.yaml +0 -1
- fusion_bench_config/method/pwe_moe/epo_for_openclip.yaml +30 -0
- fusion_bench_config/method/pwe_moe/ls_for_openclip.yaml +30 -0
- fusion_bench_config/method/{pwe_moe_ls_for_clip.yaml → pwe_moe/pwe_moe_ls_for_clip.yaml} +7 -6
- fusion_bench_config/method/rankone_moe/rankone_moe.yaml +1 -3
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +0 -1
- fusion_bench_config/method/slerp/slerp.yaml +0 -2
- fusion_bench_config/method/sparselo_pruning/llama_iterative_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_pcp_sparselo.yaml +1 -1
- fusion_bench_config/method/sparselo_pruning/llama_sparselo.yaml +1 -1
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +1 -2
- fusion_bench_config/method/task_arithmetic.yaml +1 -1
- fusion_bench_config/method/task_singular_vector/TaskSingularVectorMerging.yaml +0 -1
- fusion_bench_config/method/ties_merging.yaml +1 -1
- fusion_bench_config/method/trust_region/clip_task_arithmetic.yaml +0 -1
- fusion_bench_config/method/wemoe/sparse_weight_ensembling_moe.yaml +0 -8
- fusion_bench_config/model/clip-vit/clip-vit-base-patch16_cifar10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford-iiit-pet.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_oxford_flowers102.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_pcam.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_rendered-sst2.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_resisc45.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stanford-cars.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_stl10.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_sun397.yaml +1 -1
- fusion_bench_config/model/clip-vit/clip-vit-large-patch14_svhn.yaml +1 -1
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_TA8_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch16_individual_lora.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_TA8_control_task.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_individual.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_single_task_projection.yaml +0 -3
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-base-patch32_two_tasks_control_task.yaml +0 -4
- fusion_bench_config/modelpool/CLIPVisionModelPool/clip-vit-large-patch14_individual.yaml +0 -3
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_for_causallm.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +0 -4
- fusion_bench_config/modelpool/CausalLMPool/simle_mixtral_exp_v4.yaml +0 -1
- fusion_bench_config/modelpool/CausalLMPool/single_llama_model.yaml +0 -3
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/README.md +90 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-16_TA8.yaml +27 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA8.yaml +45 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_cars_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_cars.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_TA_sun397_dtd.yaml +23 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-B-32_individual.yaml +7 -0
- fusion_bench_config/modelpool/OpenCLIPVisionModelPool/ViT-L-14_TA8.yaml +26 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16.yaml +0 -2
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_lora16_tta.yaml +8 -10
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_glue_tta.yaml +66 -0
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-base_individual.yaml +0 -1
- fusion_bench_config/modelpool/Seq2SeqLMPool/flan-t5-large_glue_lora16.yaml +0 -3
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +0 -4
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +0 -3
- fusion_bench_config/modelpool/gpt-2_glue.yaml +0 -3
- fusion_bench_config/nyuv2_config.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/_template.yaml +0 -3
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip-vit-classification_TA8_B16.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_rankone_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/CLIPVisionModelTaskPool/clip_sparse_wemoe_clip-vit-classification_TA8.yaml +0 -2
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-16_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-B-32_TA8.yaml +24 -0
- fusion_bench_config/taskpool/OpenCLIPVisionModelTaskPool/ViT-L-14_TA8.yaml +24 -0
- fusion_bench_config/taskpool/gpt-2_glue.yaml +0 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +0 -4
- fusion_bench/method/DOGE_TA/__init__.py +0 -2
- /fusion_bench/method/{DOGE_TA → doge_ta}/layer_wise_adamerging.py +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info/licenses}/LICENSE +0 -0
- {fusion_bench-0.2.11.dist-info → fusion_bench-0.2.13.dist-info}/top_level.txt +0 -0
fusion_bench/modelpool/openclip_vision/modelpool.py
@@ -0,0 +1,255 @@
+import logging
+import pickle
+import sys
+from typing import Callable, Optional, Union, cast
+
+import torch
+from datasets import load_dataset
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+
+from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.models.open_clip import ClassificationHead, ImageEncoder
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.expr import is_expr_match
+from fusion_bench.utils.packages import _get_package_version, compare_versions
+
+log = logging.getLogger(__name__)
+
+# Add flag to track if warning has been shown
+_openclip_version_warning_shown = False
+
+
+def _check_and_redirect_open_clip_modeling():
+    global _openclip_version_warning_shown
+    if compare_versions(_get_package_version("open-clip-torch").__str__(), "2.0.2") > 0:
+        if not _openclip_version_warning_shown:
+            log.warning(
+                "OpenCLIP version is greater than 2.0.2. This may cause issues with the modelpool."
+            )
+            _openclip_version_warning_shown = True
+        import open_clip.model
+        import open_clip.transformer
+
+        if not hasattr(open_clip.model, "VisualTransformer"):
+            open_clip.model.VisualTransformer = open_clip.model.VisionTransformer
+        if not hasattr(open_clip.model, "Transformer"):
+            open_clip.model.Transformer = open_clip.transformer.Transformer
+        if not hasattr(open_clip.model, "ResidualAttentionBlock"):
+            open_clip.model.ResidualAttentionBlock = (
+                open_clip.transformer.ResidualAttentionBlock
+            )
+
+    try:
+        import src
+        import src.modeling
+    except ImportError:
+        if "src" not in sys.modules:
+            # redirect the import of `src` to `fusion_bench.models.open_clip`
+            import fusion_bench.models.open_clip as open_clip
+
+            sys.modules["src"] = open_clip
+            log.warning(
+                "`src` is not imported."
+                "Redirecting the import to `fusion_bench.models.open_clip`"
+            )
+        if "src.modeling" not in sys.modules:
+            # redirect the import of `src.modeling` to `fusion_bench.models.open_clip.modeling`
+            import fusion_bench.models.open_clip.modeling as open_clip_modeling
+
+            sys.modules["src.modeling"] = open_clip_modeling
+            log.warning(
+                "`src.modeling` is not imported."
+                "Redirecting the import to `fusion_bench.models.open_clip.modeling`"
+            )
+
+
+def load_classifier_head(model_config: Union[str, DictConfig], *args, **kwargs):
+    if isinstance(model_config, str):
+        _check_and_redirect_open_clip_modeling()
+        log.info(f"Loading `ClassificationHead` from {model_config}")
+        weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+        head = torch.load(model_config, weights_only=weights_only, *args, **kwargs)
+    elif isinstance(model_config, nn.Module):
+        log.info(f"Returning existing model: {model_config}")
+        head = model_config
+    else:
+        head = instantiate(model_config, *args, **kwargs)
+    head = cast(ClassificationHead, head)
+    return head
+
+
+class OpenCLIPVisionModelPool(BaseModelPool):
+    """
+    A model pool for managing OpenCLIP Vision models (models from task vector paper).
+    """
+
+    _train_processor = None
+    _test_processor = None
+
+    def __init__(
+        self,
+        models: DictConfig,
+        classification_heads: Optional[DictConfig] = None,
+        **kwargs,
+    ):
+        super().__init__(models, **kwargs)
+        self._classification_heads = classification_heads
+
+    @property
+    def train_processor(self):
+        if self._train_processor is None:
+            encoder: ImageEncoder = self.load_pretrained_or_first_model()
+            self._train_processor = encoder.train_preprocess
+            if self._test_processor is None:
+                self._test_processor = encoder.val_preprocess
+        return self._train_processor
+
+    @property
+    def test_processor(self):
+        if self._test_processor is None:
+            encoder: ImageEncoder = self.load_pretrained_or_first_model()
+            if self._train_processor is None:
+                self._train_processor = encoder.train_preprocess
+            self._test_processor = encoder.val_preprocess
+        return self._test_processor
+
+    def load_model(
+        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
+    ) -> ImageEncoder:
+        R"""
+        The model config can be:
+
+        - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
+        - {"model_name": str, "pickle_path": str}, load the model from the binary file (pickle format). This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from model located in the pickle file.
+        - {"model_name": str, "state_dict_path": str}, load the model from the state dict file. This will first construct the model using `ImageEncoder(model_name)`, and then load the state dict from the file.
+        - Default, load the model using `instantiate` from hydra.
+        """
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._models
+        ):
+            model_config = self._models[model_name_or_config]
+        else:
+            model_config = model_name_or_config
+        if isinstance(model_config, DictConfig):
+            model_config = OmegaConf.to_container(model_config, resolve=True)
+
+        if isinstance(model_config, str):
+            # the model config is a string, which is the path to the model checkpoint in pickle format
+            # load the model using `torch.load`
+            # this is the original usage in the task arithmetic codebase
+            _check_and_redirect_open_clip_modeling()
+            log.info(f"loading ImageEncoder from {model_config}")
+            weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+            try:
+                encoder = torch.load(
+                    model_config, weights_only=weights_only, *args, **kwargs
+                )
+            except RuntimeError as e:
+                encoder = pickle.load(open(model_config, "rb"))
+        elif is_expr_match({"model_name": str, "pickle_path": str}, model_config):
+            # the model config is a dictionary with the following keys:
+            # - model_name: str, the name of the model
+            # - pickle_path: str, the path to the binary file (pickle format)
+            # load the model from the binary file (pickle format)
+            # this is useful when you use a newer version of torchvision
+            _check_and_redirect_open_clip_modeling()
+            log.info(
+                f"loading ImageEncoder of {model_config['model_name']} from {model_config['pickle_path']}"
+            )
+            weights_only = kwargs["weights_only"] if "weights_only" in kwargs else False
+            try:
+                encoder = torch.load(
+                    model_config["pickle_path"],
+                    weights_only=weights_only,
+                    *args,
+                    **kwargs,
+                )
+            except RuntimeError as e:
+                encoder = pickle.load(open(model_config["pickle_path"], "rb"))
+            _encoder = ImageEncoder(model_config["model_name"])
+            _encoder.load_state_dict(encoder.state_dict())
+            encoder = _encoder
+        elif is_expr_match({"model_name": str, "state_dict_path": str}, model_config):
+            # the model config is a dictionary with the following keys:
+            # - model_name: str, the name of the model
+            # - state_dict_path: str, the path to the state dict file
+            # load the model from the state dict file
+            log.info(
+                f"loading ImageEncoder of {model_config['model_name']} from {model_config['state_dict_path']}"
+            )
+            encoder = ImageEncoder(model_config["model_name"])
+            encoder.load_state_dict(
+                torch.load(
+                    model_config["state_dict_path"], weights_only=True, *args, **kwargs
+                )
+            )
+        elif isinstance(model_config, nn.Module):
+            # the model config is an existing model
+            log.info(f"Returning existing model: {model_config}")
+            encoder = model_config
+        else:
+            encoder = super().load_model(model_name_or_config, *args, **kwargs)
+        encoder = cast(ImageEncoder, encoder)
+
+        # setup the train and test processors
+        if self._train_processor is None and hasattr(encoder, "train_preprocess"):
+            self._train_processor = encoder.train_preprocess
+        if self._test_processor is None and hasattr(encoder, "val_preprocess"):
+            self._test_processor = encoder.val_preprocess
+
+        return encoder
+
+    def load_classification_head(
+        self, model_name_or_config: Union[str, DictConfig], *args, **kwargs
+    ) -> ClassificationHead:
+        R"""
+        The model config can be:
+
+        - A string, which is the path to the model checkpoint in pickle format. Load directly using `torch.load`.
+        - Default, load the model using `instantiate` from hydra.
+        """
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._classification_heads
+        ):
+            model_config = self._classification_heads[model_name_or_config]
+        else:
+            model_config = model_name_or_config
+
+        head = load_classifier_head(model_config, *args, **kwargs)
+        return head
+
+    def load_train_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._train_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading train dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="train")
+        else:
+            dataset = super().load_train_dataset(dataset_name, *args, **kwargs)
+        return dataset
+
+    def load_val_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._val_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading validation dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="validation")
+        else:
+            dataset = super().load_val_dataset(dataset_name, *args, **kwargs)
+        return dataset
+
+    def load_test_dataset(self, dataset_name: str, *args, **kwargs):
+        dataset_config = self._test_datasets[dataset_name]
+        if isinstance(dataset_config, str):
+            log.info(
+                f"Loading test dataset using `datasets.load_dataset`: {dataset_config}"
+            )
+            dataset = load_dataset(dataset_config, split="test")
+        else:
+            dataset = super().load_test_dataset(dataset_name, *args, **kwargs)
+        return dataset
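For orientation, here is a minimal sketch of how the new `OpenCLIPVisionModelPool` could be driven directly from Python, based only on the constructor and `load_model` signatures added above. The checkpoint paths, the task name `sun397`, and the `_pretrained_` key are illustrative assumptions, not files or names shipped with the wheel.

```python
from omegaconf import OmegaConf

from fusion_bench.modelpool.openclip_vision.modelpool import OpenCLIPVisionModelPool

# Hypothetical checkpoint paths; the two entries mirror two of the config forms
# accepted by `load_model`: a plain pickle path, and a
# {"model_name": ..., "state_dict_path": ...} mapping. The `_pretrained_` key
# follows fusion_bench's usual convention for the base model and is assumed here.
models = OmegaConf.create(
    {
        "_pretrained_": "checkpoints/ViT-B-32/zeroshot.pt",
        "sun397": {
            "model_name": "ViT-B-32",
            "state_dict_path": "checkpoints/ViT-B-32/sun397/finetuned.pt",
        },
    }
)

pool = OpenCLIPVisionModelPool(models=models)
encoder = pool.load_model("sun397")  # -> ImageEncoder
preprocess = pool.test_processor     # val transform taken from a loaded encoder
```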
fusion_bench/models/open_clip/modeling.py
@@ -0,0 +1,176 @@
+from typing import Callable, List
+
+import open_clip
+import torch
+from torch import Tensor
+
+from . import utils
+from .variables_and_paths import CACHEDIR, MODELS, OPENCLIP_CACHEDIR
+
+
+class ImageEncoder(torch.nn.Module):
+    R"""
+    Examples:
+
+    load the image encoder for a given model name
+
+    >>> from fusion_bench.models.open_clip import ImageEncoder
+    >>> image_encoder = ImageEncoder(model_name="ViT-B-32")
+    """
+
+    def __init__(self, model_name: str, keep_lang=False):
+        super().__init__()
+        assert (
+            model_name in MODELS
+        ), f"Invalid model name: {model_name}. Valid models are: {MODELS}"
+
+        if "__pretrained__" in model_name:
+            name, pretrained = model_name.split("__pretrained__")
+        elif "__init__" in model_name:
+            print("Using random initialization.")
+            name, pretrained = model_name.split("__init__")[0], None
+        else:
+            name = model_name
+            pretrained = "openai"
+        (
+            self.model,
+            self.train_preprocess,
+            self.val_preprocess,
+        ) = open_clip.create_model_and_transforms(
+            name, pretrained=pretrained, cache_dir=OPENCLIP_CACHEDIR
+        )
+
+        self.cache_dir = CACHEDIR
+
+        if not keep_lang and hasattr(self.model, "transformer"):
+            delattr(self.model, "transformer")
+
+    def forward(self, images):
+        assert self.model is not None
+        return self.model.encode_image(images)
+
+    def __call__(self, inputs):
+        return self.forward(inputs)
+
+    def save(self, filename):
+        print(f"Saving image encoder to {filename}")
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, model_name, filename):
+        print(f"Loading image encoder from {filename}")
+
+        state_dict = torch.load(filename, map_location="cpu")
+
+        model = cls(model_name)
+        model.load_state_dict(state_dict)
+        return model
+
+
+class ClassificationHead(torch.nn.Linear):
+    def __init__(
+        self,
+        normalize: bool,
+        weights: Tensor,
+        biases: Tensor = None,
+    ):
+        output_size, input_size = weights.shape
+        super().__init__(input_size, output_size)
+        self.normalize = normalize
+        if weights is not None:
+            self.weight = torch.nn.Parameter(weights.clone())
+        if biases is not None:
+            self.bias = torch.nn.Parameter(biases.clone())
+        else:
+            self.bias = torch.nn.Parameter(torch.zeros_like(self.bias))
+
+    def forward(self, inputs: Tensor):
+        if self.normalize:
+            inputs = inputs / inputs.norm(dim=-1, keepdim=True)
+        return super().forward(inputs)
+
+    def __call__(self, inputs: Tensor):
+        return self.forward(inputs)
+
+    def save(self, filename):
+        print(f"Saving classification head to {filename}")
+        utils.torch_save(self, filename, save_state_dict=False)
+
+    @classmethod
+    def load(cls, filename):
+        # print(f"Loading classification head from {filename}")
+        return utils.torch_load(filename)
+
+
+class ImageClassifier(torch.nn.Module):
+    train_preprocess: Callable
+    val_preprocess: Callable
+
+    def __init__(
+        self,
+        image_encoder: ImageEncoder,
+        classification_head: ClassificationHead,
+    ):
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.classification_head = classification_head
+        if self.image_encoder is not None:
+            self.train_preprocess = self.image_encoder.train_preprocess
+            self.val_preprocess = self.image_encoder.val_preprocess
+
+    def freeze_head(self):
+        self.classification_head.weight.requires_grad_(False)
+        self.classification_head.bias.requires_grad_(False)
+
+    def forward(self, inputs: Tensor):
+        features = self.image_encoder(inputs)
+        outputs = self.classification_head(features)
+        return outputs
+
+    def __call__(self, inputs):
+        return self.forward(inputs)
+
+    def save(self, filename):
+        print(f"Saving image classifier to {filename}")
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, filename):
+        print(f"Loading image classifier from {filename}")
+        return utils.torch_load(filename)
+
+
+class MultiHeadImageClassifier(torch.nn.Module):
+    def __init__(
+        self,
+        image_encoder: ImageEncoder,
+        classification_heads: List[ClassificationHead],
+    ):
+        super().__init__()
+        self.image_encoder = image_encoder
+        self.classification_heads = torch.nn.ModuleList(classification_heads)
+        if self.image_encoder is not None:
+            self.train_preprocess = self.image_encoder.train_preprocess
+            self.val_preprocess = self.image_encoder.val_preprocess
+
+    def freeze_head(self):
+        for idx in range(len(self.classification_heads)):
+            self.classification_heads[idx].weight.requires_grad_(False)
+            self.classification_heads[idx].bias.requires_grad_(False)
+
+    def forward(self, inputs, head_idx):
+        features = self.image_encoder(inputs)
+        outputs = self.classification_heads[head_idx](features)
+        return outputs
+
+    def __call__(self, inputs, head_idx):
+        return self.forward(inputs, head_idx)
+
+    def save(self, filename):
+        print(f"Saving image classifier to {filename}")
+        utils.torch_save(self, filename)
+
+    @classmethod
+    def load(cls, filename):
+        print(f"Loading image classifier from {filename}")
+        return utils.torch_load(filename)