fusion-bench 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/compat/method/__init__.py +1 -0
- fusion_bench/compat/method/base_algorithm.py +7 -1
- fusion_bench/compat/modelpool/__init__.py +1 -1
- fusion_bench/compat/taskpool/__init__.py +1 -1
- fusion_bench/dataset/arc_agi/arc.py +5 -0
- fusion_bench/dataset/arc_agi/preprocess.py +1 -1
- fusion_bench/dataset/llama/__init__.py +1 -0
- fusion_bench/dataset/llama/alpaca.py +93 -3
- fusion_bench/dataset/llama/collate.py +62 -2
- fusion_bench/dataset/llama/metamathqa.py +50 -0
- fusion_bench/dataset/llama/preference_700k.py +70 -0
- fusion_bench/dataset/llama/stanford_shp.py +90 -0
- fusion_bench/dataset/llama/ultrachat.py +58 -0
- fusion_bench/dataset/llama/utils/__init__.py +0 -0
- fusion_bench/method/__init__.py +1 -1
- fusion_bench/method/adamerging/layer_wise_adamerging.py +11 -4
- fusion_bench/method/adamerging/min_norm_solvers.py +4 -4
- fusion_bench/method/linear/expo.py +39 -0
- fusion_bench/method/lm_finetune/__init__.py +1 -0
- fusion_bench/method/lm_finetune/bradley_terry_rm.py +432 -0
- fusion_bench/method/lm_finetune/fullfinetune_sft.py +90 -160
- fusion_bench/method/lm_finetune/peftfinetune_sft.py +49 -139
- fusion_bench/method/pruning/llama_magnitude_prune.py +2 -2
- fusion_bench/method/pruning/llama_random_prune.py +2 -2
- fusion_bench/method/surgery/__init__.py +3 -0
- fusion_bench/method/surgery/clip_layer_wise_adamerging_surgery.py +157 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/clip_classification.py +58 -5
- fusion_bench/mixins/fabric_training.py +320 -0
- fusion_bench/mixins/lightning_fabric.py +9 -0
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/causal_lm/__init__.py +1 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +21 -22
- fusion_bench/modelpool/seq_classification_lm/__init__.py +2 -0
- fusion_bench/modelpool/seq_classification_lm/reward_model.py +15 -0
- fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py +98 -0
- fusion_bench/models/chat_templates/__init__.py +1 -0
- fusion_bench/models/chat_templates/llama_3_Instruct.py +1 -0
- fusion_bench/models/chat_templates/load_tokenizer.py +43 -0
- fusion_bench/models/hf_clip.py +50 -9
- fusion_bench/models/surgery/surgerymodelwrapper.py +157 -0
- fusion_bench/models/utils.py +8 -0
- fusion_bench/models/wrappers/layer_wise_fusion.py +14 -5
- fusion_bench/models/wrappers/task_wise_fusion.py +5 -5
- fusion_bench/optim/__init__.py +2 -0
- fusion_bench/optim/exception.py +47 -0
- fusion_bench/optim/lr_scheduler/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/linear_warmup.py +222 -0
- fusion_bench/optim/lr_scheduler/utils/__init__.py +1 -0
- fusion_bench/optim/lr_scheduler/utils/visualization.py +119 -0
- fusion_bench/optim/mezo.py +0 -2
- fusion_bench/programs/fabric_fusion_program.py +5 -1
- fusion_bench/taskpool/clip_vision/taskpool.py +43 -6
- fusion_bench/taskpool/llama/reward_model.py +157 -0
- fusion_bench/taskpool/nyuv2_taskpool.py +2 -0
- fusion_bench/utils/hydra_utils.py +22 -0
- fusion_bench/utils/plot/__init__.py +0 -0
- fusion_bench/utils/plot/token.py +52 -0
- fusion_bench/utils/plot/token_notebook.py +127 -0
- fusion_bench/utils/type.py +5 -3
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/RECORD +87 -47
- fusion_bench_config/clip-vit-base-patch32_robustness_corrupted.yaml +1 -1
- fusion_bench_config/dataset/llm_sft/alpaca_cleaned.yaml +6 -0
- fusion_bench_config/dataset/llm_sft/ultrachat_200k.yaml +3 -0
- fusion_bench_config/fabric/llama_peft_fsdp.yaml +16 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric/strategy/deepspeed.yaml +10 -0
- fusion_bench_config/fabric/strategy/llama_peft_fsdp.yaml +9 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -1
- fusion_bench_config/llama_full_finetune.yaml +19 -0
- fusion_bench_config/method/lm_finetune/bradley_terry_rm.yaml +47 -0
- fusion_bench_config/method/lm_finetune/fullfinetune_sft.yaml +11 -4
- fusion_bench_config/method/lm_finetune/peftfinetune_sft.yaml +4 -2
- fusion_bench_config/method/surgery/adamerging_surgery.yaml +27 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_alpaca_cleaned.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_codealpaca.yaml +21 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_metamathqa.yaml +19 -0
- fusion_bench_config/modelpool/CausalLMPool/llama_ultrachat.yaml +18 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/llama_preference700k.yaml +23 -0
- fusion_bench_config/modelpool/SeqenceClassificationModelPool/single_reward_model.yaml +14 -0
- fusion_bench_config/nyuv2_config.yaml +5 -1
- fusion_bench_config/taskpool/reward_model_evaluation.yaml +18 -0
- fusion_bench_config/llama_weighted_average.yaml +0 -26
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/LICENSE +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.6.dist-info → fusion_bench-0.2.7.dist-info}/top_level.txt +0 -0
fusion_bench/modelpool/causal_lm/causal_lm.py
CHANGED
@@ -3,6 +3,7 @@ import os
 from copy import deepcopy
 from typing import Any, Optional, TypeAlias, Union, cast  # noqa: F401
 
+import peft
 from omegaconf import DictConfig, flag_override
 from torch import nn
 from torch.nn.modules import Module
@@ -23,28 +24,6 @@ log = logging.getLogger(__name__)
 CausalLM: TypeAlias = Union[LlamaForCausalLM, MistralForCausalLM, Any]
 
 
-def config_priority_get(priority_config, general_config, key, default):
-    """
-    Retrieve a configuration value with priority.
-
-    This function retrieves the value associated with `key` from `priority_config` if it exists.
-    If the key is not found in `priority_config`, it retrieves the value from `general_config`.
-    If the key is not found in either configuration, it returns the provided `default` value.
-
-    Args:
-        priority_config (dict): The configuration dictionary with higher priority.
-        general_config (dict): The general configuration dictionary.
-        key (str): The key to look up in the configuration dictionaries.
-        default: The default value to return if the key is not found in either configuration.
-
-    Returns:
-        The value associated with `key` from `priority_config` or `general_config`, or the `default` value if the key is not found.
-    """
-    if key in priority_config:
-        return priority_config[key]
-    return general_config.get(key, default)
-
-
 class CausalLMPool(BaseModelPool):
     _config_mapping = BaseModelPool._config_mapping | {
         "_tokenizer": "tokenizer",
@@ -138,3 +117,23 @@ class CausalLMBackbonePool(CausalLMPool):
             model_name_or_config, *args, **kwargs
         )
         return model.model.layers
+
+
+def load_peft_causal_lm(
+    base_model_path: str,
+    peft_model_path: str,
+    torch_dtype: str = "bfloat16",
+    is_trainable: bool = True,
+    merge_and_unload: bool = False,
+):
+    base_model = LlamaForCausalLM.from_pretrained(
+        base_model_path, torch_dtype=torch_dtype
+    )
+    model = peft.PeftModel.from_pretrained(
+        base_model,
+        peft_model_path,
+        is_trainable=is_trainable,
+    )
+    if merge_and_unload:
+        model = model.merge_and_unload()
+    return model
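A minimal usage sketch of the new load_peft_causal_lm helper (the checkpoint and adapter paths below are placeholders, not files shipped with the package):

from fusion_bench.modelpool.causal_lm.causal_lm import load_peft_causal_lm

# Load a frozen Llama base model plus a LoRA adapter; merge_and_unload folds the
# adapter weights back into the base model so the result behaves like a plain
# LlamaForCausalLM.
model = load_peft_causal_lm(
    base_model_path="meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    peft_model_path="outputs/my-lora-adapter",   # placeholder adapter directory
    torch_dtype="bfloat16",
    is_trainable=False,
    merge_and_unload=True,
)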
fusion_bench/modelpool/seq_classification_lm/reward_model.py
ADDED
@@ -0,0 +1,15 @@
+from transformers import AutoModelForSequenceClassification
+
+
+def create_reward_model_from_pretrained(pretrained_model_name_or_path: str, **kwargs):
+    """
+    Create a reward model for reward modeling (RLHF).
+
+    Args:
+        pretrained_model_name_or_path (str): The name or path of the pretrained model.
+        **kwargs: Additional keyword arguments passed to the model class.
+    """
+    model = AutoModelForSequenceClassification.from_pretrained(
+        pretrained_model_name_or_path, num_labels=1, **kwargs
+    )
+    return model
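A short sketch of calling the new factory; the checkpoint name is a placeholder and the extra keyword argument is forwarded to from_pretrained via **kwargs:

import torch
from fusion_bench.modelpool.seq_classification_lm.reward_model import (
    create_reward_model_from_pretrained,
)

# num_labels=1 is set inside the helper, so the model exposes a scalar reward head.
reward_model = create_reward_model_from_pretrained(
    "meta-llama/Llama-3.2-1B",   # placeholder base checkpoint
    torch_dtype=torch.bfloat16,  # forwarded to from_pretrained
)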
fusion_bench/modelpool/seq_classification_lm/seq_classification_lm.py
ADDED
@@ -0,0 +1,98 @@
+import logging
+import os
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Optional, TypeAlias, Union, cast  # noqa: F401
+
+from omegaconf import DictConfig, flag_override
+from transformers import PreTrainedModel, PreTrainedTokenizer
+from typing_extensions import override
+
+from fusion_bench.modelpool import BaseModelPool
+from fusion_bench.utils import instantiate
+from fusion_bench.utils.dtype import parse_dtype
+
+if TYPE_CHECKING:
+    from transformers import LlamaForSequenceClassification
+
+log = logging.getLogger(__name__)
+
+
+class SeqenceClassificationModelPool(BaseModelPool):
+
+    def __init__(
+        self,
+        models,
+        *,
+        tokenizer: Optional[DictConfig],
+        model_kwargs: Optional[DictConfig] = None,
+        **kwargs,
+    ):
+        super().__init__(models, **kwargs)
+        # process `model_kwargs`
+        self._tokenizer = tokenizer
+        self._model_kwargs = model_kwargs
+        if self._model_kwargs is None:
+            self._model_kwargs = DictConfig({})
+        with flag_override(self._model_kwargs, "allow_objects", True):
+            if hasattr(self._model_kwargs, "torch_dtype"):
+                self._model_kwargs.torch_dtype = parse_dtype(
+                    self._model_kwargs.torch_dtype
+                )
+
+    @override
+    def load_model(
+        self,
+        model_name_or_config: str | DictConfig,
+        *args,
+        **kwargs,
+    ) -> Union[PreTrainedModel, "LlamaForSequenceClassification"]:
+        model_kwargs = deepcopy(self._model_kwargs)
+        model_kwargs.update(kwargs)
+        if isinstance(model_name_or_config, str):
+            log.info(f"Loading model: {model_name_or_config}", stacklevel=2)
+        return super().load_model(model_name_or_config, *args, **model_kwargs)
+
+    def load_tokenizer(self, *args, **kwargs) -> PreTrainedTokenizer:
+        assert self._tokenizer is not None, "Tokenizer is not defined in the config"
+        log.info("Loading tokenizer.", stacklevel=2)
+        tokenizer = instantiate(self._tokenizer, *args, **kwargs)
+        return tokenizer
+
+    @override
+    def save_model(
+        self,
+        model: PreTrainedModel,
+        path: str,
+        push_to_hub: bool = False,
+        model_dtype: Optional[str] = None,
+        save_tokenizer: bool = False,
+        tokenizer_kwargs=None,
+        **kwargs,
+    ):
+        """
+        Save the model to the specified path.
+
+        Args:
+            model (PreTrainedModel): The model to be saved.
+            path (str): The path where the model will be saved.
+            push_to_hub (bool, optional): Whether to push the model to the Hugging Face Hub. Defaults to False.
+            save_tokenizer (bool, optional): Whether to save the tokenizer along with the model. Defaults to False.
+            **kwargs: Additional keyword arguments passed to the `save_pretrained` method.
+        """
+        path = os.path.expanduser(path)
+        if save_tokenizer:
+            if tokenizer_kwargs is None:
+                tokenizer_kwargs = {}
+            # load the tokenizer
+            tokenizer = self.load_tokenizer(**tokenizer_kwargs)
+            tokenizer.save_pretrained(
+                path,
+                push_to_hub=push_to_hub,
+            )
+        if model_dtype is not None:
+            model.to(dtype=parse_dtype(model_dtype))
+        model.save_pretrained(
+            path,
+            push_to_hub=push_to_hub,
+            **kwargs,
+        )
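A sketch of how this pool might be constructed programmatically, assuming Hydra-style _target_ entries like the YAML configs listed above; the checkpoint name and config layout are illustrative, not taken from the package:

from omegaconf import DictConfig
from fusion_bench.modelpool.seq_classification_lm.seq_classification_lm import (
    SeqenceClassificationModelPool,
)

# Hypothetical config: one reward model plus a tokenizer, both resolved through
# fusion_bench.utils.instantiate when load_model / load_tokenizer are called.
pool = SeqenceClassificationModelPool(
    models=DictConfig(
        {
            "reward_model": {
                "_target_": "fusion_bench.modelpool.seq_classification_lm.reward_model.create_reward_model_from_pretrained",
                "pretrained_model_name_or_path": "meta-llama/Llama-3.2-1B",  # placeholder
            }
        }
    ),
    tokenizer=DictConfig(
        {
            "_target_": "transformers.AutoTokenizer.from_pretrained",
            "pretrained_model_name_or_path": "meta-llama/Llama-3.2-1B",  # placeholder
        }
    ),
    model_kwargs=DictConfig({"torch_dtype": "bfloat16"}),  # parsed by parse_dtype in __init__
)

model = pool.load_model("reward_model")
tokenizer = pool.load_tokenizer()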
fusion_bench/models/chat_templates/__init__.py
ADDED
@@ -0,0 +1 @@
+from .load_tokenizer import chat_template_mapping, load_tokenizer_with_chat_template
fusion_bench/models/chat_templates/llama_3_Instruct.py
ADDED
@@ -0,0 +1 @@
CHAT_TEMPLATE = '{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now("%d %b %Y") %}\n {%- else %}\n {%- set date_string = "26 Jul 2024" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0][\'role\'] == \'system\' %}\n {%- set system_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- System message #}\n{{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n{%- if tools is not none %}\n {{- "Environment: ipython\\n" }}\n{%- endif %}\n{{- "Cutting Knowledge Date: December 2023\\n" }}\n{{- "Today Date: " + date_string + "\\n\\n" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- "<|eot_id|>" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0][\'content\']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there\'s no first user message!") }}\n{%- endif %}\n {{- \'<|start_header_id|>user<|end_header_id|>\\n\\n\' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- \'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.\' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == \'ipython\' or message.role == \'tool\' or \'tool_calls\' in message) %}\n {{- \'<|start_header_id|>\' + message[\'role\'] + \'<|end_header_id|>\\n\\n\'+ message[\'content\'] | trim + \'<|eot_id|>\' }}\n {%- elif \'tool_calls\' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' -}}\n {{- \'{"name": "\' + tool_call.name + \'", \' }}\n {{- \'"parameters": \' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {{- "<|eot_id|>" }}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- 
"<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|start_header_id|>assistant<|end_header_id|>\\n\\n\' }}\n{%- endif %}\n'
fusion_bench/models/chat_templates/load_tokenizer.py
ADDED
@@ -0,0 +1,43 @@
+import logging
+
+from transformers import AutoTokenizer
+
+from .llama_3_Instruct import CHAT_TEMPLATE as LLAMA_3_INSTRUCT_CHAT_TEMPLATE
+
+chat_template_mapping = {"llama_3_instruct": LLAMA_3_INSTRUCT_CHAT_TEMPLATE}
+
+log = logging.getLogger(__name__)
+
+
+def load_tokenizer_with_chat_template(
+    pretrained_model_name_or_path: str,
+    model_family: str,
+    overwrite_chat_template: bool = True,
+    **kwargs,
+):
+    """
+    Load the tokenizer for Llama 3 model.
+
+    Args:
+        pretrained_model_name_or_path (str): The name or path of the pretrained model.
+        model_family (str): The model family.
+        **kwargs: Additional keyword arguments passed to the tokenizer class.
+    """
+    assert (
+        model_family in chat_template_mapping
+    ), f"Model family {model_family} not found. Available model families: {chat_template_mapping.keys()}"
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        pretrained_model_name_or_path,
+        **kwargs,
+    )
+
+    if tokenizer.chat_template is None:
+        tokenizer.chat_template = chat_template_mapping[model_family]
+    else:
+        if overwrite_chat_template:
+            log.warning("Overwriting the chat template with the default chat template.")
+            tokenizer.chat_template = chat_template_mapping[model_family]
+        else:
+            log.warning("Chat template already exists. Skipping overwriting.")
+    return tokenizer
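A usage sketch for the tokenizer helper above; the checkpoint name is a placeholder, and "llama_3_instruct" is the only key currently registered in chat_template_mapping:

from fusion_bench.models.chat_templates import load_tokenizer_with_chat_template

tokenizer = load_tokenizer_with_chat_template(
    "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder checkpoint
    model_family="llama_3_instruct",        # must be a key of chat_template_mapping
)

# The attached Jinja template can then be used through the normal tokenizer API.
prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Summarize model merging in one sentence."}],
    tokenize=False,
    add_generation_prompt=True,
)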
fusion_bench/models/hf_clip.py
CHANGED
@@ -1,4 +1,5 @@
-
+import logging
+from typing import TYPE_CHECKING, Callable, Iterable, List  # noqa: F401
 
 import torch
 from torch import Tensor, nn
@@ -7,6 +8,11 @@ from transformers.models.clip.modeling_clip import BaseModelOutputWithPooling
 
 from fusion_bench.utils.devices import get_device
 
+if TYPE_CHECKING:
+    from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper
+
+log = logging.getLogger(__name__)
+
 default_templates = [
     lambda c: f"a photo of a {c}",
 ]
@@ -33,6 +39,7 @@ class HFCLIPClassifier(nn.Module):
         self,
         clip_model: CLIPModel,
         processor: CLIPProcessor,
+        extra_module=None,
     ):
         """
         Initialize the HFCLIPClassifier.
@@ -56,6 +63,8 @@ class HFCLIPClassifier(nn.Module):
             persistent=False,
         )
 
+        self.extra_module = extra_module
+
     @property
     def text_model(self):
         """Get the text model component of CLIP."""
@@ -111,7 +120,13 @@ class HFCLIPClassifier(nn.Module):
 
         self.zeroshot_weights = zeroshot_weights
 
-    def forward(
+    def forward(
+        self,
+        images: Tensor,
+        return_image_embeds=False,
+        return_dict=False,
+        task_name=None,
+    ):
         """
         Perform forward pass for zero-shot image classification.
 
@@ -120,6 +135,9 @@ class HFCLIPClassifier(nn.Module):
 
         Args:
             images (Tensor): Input images to classify.
+            return_image_embeds (bool): Whether to return the image embeddings.
+            return_dict (bool): Whether to return a dictionary with logits and image embeddings.
+            task_name (Optional[str]): The name of the task.
 
         Returns:
             Tensor: Classification logits for each input image.
@@ -131,16 +149,22 @@ class HFCLIPClassifier(nn.Module):
             raise ValueError("Must set classification task before forward pass")
         text_embeds = self.zeroshot_weights
 
-        image_embeds = self.vision_model(images)
-        if isinstance(image_embeds, Tensor):
-            pass
-        elif isinstance(image_embeds, BaseModelOutputWithPooling):
-            image_embeds = image_embeds[1]
-        image_embeds = self.clip_model.visual_projection(image_embeds)
-
+        image_embeds = self.get_image_features(images)
         # normalize embeddings
         image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
 
+        if (
+            hasattr(self.vision_model, "is_surgery_model")
+            and self.vision_model.is_surgery_model
+        ):
+            # Dealing with the surgery model, for more details, please refer to:
+            # (ICML 2024) Yang, et.al. Representation Surgery for Multi-Task Model Merging
+            # https://arxiv.org/abs/2402.02705
+            self.vision_model: "SurgeryModelWrapper" = self.vision_model
+            image_embeds, _, _ = self.vision_model.compute_surgery_features(
+                image_embeds, dataset_name=task_name
+            )
+
         # cosine similarity
         logit_scale = self.clip_model.logit_scale.exp()
         logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
@@ -156,3 +180,20 @@ class HFCLIPClassifier(nn.Module):
             return logits_per_image, image_embeds
         else:
             return logits_per_image
+
+    def get_image_features(self, images: Tensor) -> Tensor:
+        """
+        Compute the image embeddings.
+
+        Returns:
+            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
+                applying the projection layer to the pooled output of [`CLIPVisionModel`].
+        """
+
+        image_embeds = self.vision_model(images)
+        if isinstance(image_embeds, Tensor):
+            pass
+        elif isinstance(image_embeds, BaseModelOutputWithPooling):
+            image_embeds = image_embeds[1]
+        image_embeds = self.clip_model.visual_projection(image_embeds)
+        return image_embeds
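A sketch of the extended forward interface; the checkpoint and class names are placeholders, and set_classification_task is assumed to be the existing method that populates zeroshot_weights. With return_image_embeds=True the call is expected to return the logits together with the image embeddings, as in the branch shown above:

import torch
from transformers import CLIPModel, CLIPProcessor
from fusion_bench.models.hf_clip import HFCLIPClassifier

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
classifier = HFCLIPClassifier(clip_model, processor)
classifier.set_classification_task(["cat", "dog"])  # assumed existing API

pixel_values = torch.randn(2, 3, 224, 224)  # placeholder batch of preprocessed images
logits, image_embeds = classifier(pixel_values, return_image_embeds=True)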
fusion_bench/models/surgery/surgerymodelwrapper.py
ADDED
@@ -0,0 +1,157 @@
+import math
+from typing import TYPE_CHECKING, List, Union, Callable, Generic
+
+import torch
+from torch import nn
+from transformers.models.clip.modeling_clip import (
+    CLIPVisionModel,
+    CLIPVisionTransformer,
+)
+from fusion_bench.utils.type import TorchModelType
+
+
+def regularize_name(name: str):
+    name = name.replace("-", "_")
+    name = name.replace(".", "_")
+    return name
+
+
+class SurgeryModelWrapper(torch.nn.Module, Generic[TorchModelType]):
+
+    is_surgery_model = True
+    """A flag to indicate that this is a surgery model."""
+
+    def __init__(
+        self,
+        model: TorchModelType,
+        test_datasets: List[str],
+        projection_dim: int = 512,
+        hidden_dim: int = 16,
+    ):
+        super(SurgeryModelWrapper, self).__init__()
+        self.model = model
+        self.model.requires_grad_(False)
+
+        self.test_datasets = test_datasets
+        self.non_linear_func = torch.nn.ReLU()
+
+        self.projection_dim = projection_dim
+        self.hidden_dim = hidden_dim
+
+        for dataset_name in test_datasets:
+            self.add_surgery_module(dataset_name)
+
+    def add_surgery_module(self, dataset_name: str):
+        """
+        Add a surgery module for a given dataset.
+
+        Args:
+            dataset_name (str): The name of the dataset.
+        """
+        dataset_name = regularize_name(dataset_name)
+
+        down_proj = torch.nn.Linear(self.projection_dim, self.hidden_dim, bias=False)
+        up_proj = torch.nn.Linear(self.hidden_dim, self.projection_dim, bias=False)
+
+        torch.nn.init.kaiming_uniform_(down_proj.weight, a=math.sqrt(5))
+        torch.nn.init.zeros_(up_proj.weight)
+
+        self.add_module(
+            "feature_mapping_to_head_down_proj_{}".format(dataset_name), down_proj
+        )
+        self.add_module(
+            "feature_mapping_to_head_up_proj_{}".format(dataset_name), up_proj
+        )
+
+    def collect_trainable_params(self):
+        trainable_params = []
+
+        # surgery parameter
+        for dataset_name in self.test_datasets:
+            dataset_name = regularize_name(dataset_name)
+            down_proj = getattr(
+                self, "feature_mapping_to_head_down_proj_{}".format(dataset_name)
+            )
+            up_proj = getattr(
+                self, "feature_mapping_to_head_up_proj_{}".format(dataset_name)
+            )
+            trainable_params.append(down_proj.weight)
+            trainable_params.append(up_proj.weight)
+        return trainable_params
+
+    def collect_surgery_module(self):
+        surgery_module = {}
+
+        # surgery parameter
+        for dataset_name in self.test_datasets:
+            dataset_name = regularize_name(dataset_name)
+            down_proj = getattr(
+                self, "feature_mapping_to_head_down_proj_{}".format(dataset_name)
+            )
+            up_proj = getattr(
+                self, "feature_mapping_to_head_up_proj_{}".format(dataset_name)
+            )
+            surgery_module[
+                "feature_mapping_to_head_down_proj_{}".format(dataset_name)
+            ] = down_proj
+            surgery_module[
+                "feature_mapping_to_head_up_proj_{}".format(dataset_name)
+            ] = up_proj
+
+        surgery_module["non_linear_func"] = self.non_linear_func
+
+        return surgery_module
+
+    def compute_surgery_features(
+        self,
+        compute_features_fn: Union[
+            torch.Tensor, Callable[[TorchModelType], torch.Tensor]
+        ],
+        dataset_name: str,
+    ):
+        """
+        Compute the surgery features.
+
+        Args:
+            compute_features_fn (Union[torch.Tensor, Callable[[nn.Module], torch.Tensor]]): A function that computes the features or a tensor that represents the features.
+            dataset_name (str): The name of the dataset.
+
+        Returns:
+            feature (torch.Tensor): The surgery features.
+            feature0 (torch.Tensor): The original features.
+            feature_sub (torch.Tensor): feature0 - feature.
+        """
+        dataset_name = regularize_name(dataset_name)
+
+        if isinstance(compute_features_fn, torch.Tensor):
+            feature = compute_features_fn
+        elif callable(compute_features_fn):
+            feature = compute_features_fn(self.model)
+        else:
+            raise ValueError(
+                "compute_features_fn must be a tensor or a callable, but got {}".format(
+                    type(compute_features_fn)
+                )
+            )
+
+        feature0 = feature
+
+        # feature bias
+        down_proj = getattr(
+            self, "feature_mapping_to_head_down_proj_{}".format(dataset_name)
+        )
+        up_proj = getattr(
+            self, "feature_mapping_to_head_up_proj_{}".format(dataset_name)
+        )
+        feature_sub = down_proj(feature)
+        feature_sub = self.non_linear_func(feature_sub)
+        feature_sub = up_proj(feature_sub)
+
+        # surgery feature
+        feature = feature0 - feature_sub
+
+        return feature, feature0, feature_sub
+
+    def forward(self, *args, **kwargs):
+        """The wrappered model should just forward like normal."""
+        return self.model(*args, **kwargs)
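A minimal sketch of the wrapper's API as added above; the dataset names are placeholders and the random tensor stands in for projected CLIP image features:

import torch
from transformers import CLIPVisionModel
from fusion_bench.models.surgery.surgerymodelwrapper import SurgeryModelWrapper

# Wrap a (merged) vision backbone; one down/up projection pair is registered per
# test dataset, and only those projections are returned as trainable parameters.
backbone = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
wrapper = SurgeryModelWrapper(backbone, test_datasets=["svhn", "gtsrb"], projection_dim=512)

features = torch.randn(8, 512)  # placeholder pre-computed image features
corrected, original, bias = wrapper.compute_surgery_features(features, dataset_name="svhn")

optimizer = torch.optim.Adam(wrapper.collect_trainable_params(), lr=1e-3)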
fusion_bench/models/utils.py
CHANGED
@@ -1,5 +1,6 @@
 from typing import List
 
+import torch
 from torch import nn
 
 
@@ -70,3 +71,10 @@ def find_layers_with_type(
         if isinstance(submodule, tuple(layer_types)):
             res[name] = submodule
     return res
+
+
+def disable_dropout(model: torch.nn.Module):
+    """Disable dropout in a model."""
+    for module in model.modules():
+        if isinstance(module, torch.nn.Dropout):
+            module.p = 0
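A small sketch of the new helper, useful in places such as reward-model or preference fine-tuning where dropout should be switched off:

import torch.nn as nn
from fusion_bench.models.utils import disable_dropout

model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.1))
disable_dropout(model)
assert all(m.p == 0 for m in model.modules() if isinstance(m, nn.Dropout))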
fusion_bench/models/wrappers/layer_wise_fusion.py
CHANGED
@@ -1,13 +1,22 @@
 import functools
 import logging
 from copy import deepcopy
-from typing import
+from typing import (  # noqa: F401
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    TypeVar,
+)
 
 import torch
 from torch import Tensor, nn
 from torch.func import functional_call
 
-from fusion_bench.utils.type import StateDictType
+from fusion_bench.utils.type import TorchModelType, StateDictType
 
 __all__ = ["get_layer_wise_weights", "fuse_weights", "LayerWiseMergedModel"]
 
@@ -132,14 +141,14 @@ def fuse_weights(
     }
 
 
-class LayerWiseMergedModel(nn.Module):
+class LayerWiseMergedModel(nn.Module, Generic[TorchModelType]):
     _merged_state_dict: StateDictType = None
 
     def __init__(
         self,
         layer_wise_weight: Tensor,
-        pretrained_model:
-        finetuned_models: List[
+        pretrained_model: TorchModelType,
+        finetuned_models: List[TorchModelType],
         clamp_weights: bool = True,
         tie_weights: bool = False,
         strict: bool = True,
fusion_bench/models/wrappers/task_wise_fusion.py
CHANGED
@@ -16,13 +16,13 @@ outputs = merged_model(inputs)
 
 import functools
 import logging
-from typing import Any, Callable, Dict, Iterator, List, Optional  # noqa: F401
+from typing import Any, Callable, Dict, Generic, Iterator, List, Optional  # noqa: F401
 
 import torch
 from torch import Tensor, nn
 from torch.func import functional_call
 
-from fusion_bench.utils.type import StateDictType
+from fusion_bench.utils.type import TorchModelType, StateDictType
 
 log = logging.getLogger(__name__)
 
@@ -157,14 +157,14 @@ def fuse_weights(
     }
 
 
-class TaskWiseMergedModel(nn.Module):
+class TaskWiseMergedModel(nn.Module, Generic[TorchModelType]):
     _merged_state_dict: StateDictType = None
 
     def __init__(
         self,
        task_wise_weight: Tensor,
-        pretrained_model:
-        finetuned_models: List[
+        pretrained_model: TorchModelType,
+        finetuned_models: List[TorchModelType],
         clamp_weights: bool = True,
         tie_weights: bool = False,
         strict: bool = True,
fusion_bench/optim/__init__.py
CHANGED
fusion_bench/optim/exception.py
ADDED
@@ -0,0 +1,47 @@
+class NoSparseGradientError(Exception):
+    """Raised when the gradient is sparse gradient.
+
+    :param optimizer_name: str. optimizer name.
+    :param note: str. special conditions to note (default '').
+    """
+
+    def __init__(self, optimizer_name: str, note: str = ""):
+        self.note: str = " " if not note else f" w/ {note} "
+        self.message: str = (
+            f"[-] {optimizer_name}{self.note}does not support sparse gradient."
+        )
+        super().__init__(self.message)
+
+
+class ZeroParameterSizeError(Exception):
+    """Raised when the parameter size is 0."""
+
+    def __init__(self):
+        self.message: str = "[-] parameter size is 0"
+        super().__init__(self.message)
+
+
+class NoClosureError(Exception):
+    """Raised when there's no closure function."""
+
+    def __init__(self, optimizer_name: str, note: str = ""):
+        self.message: str = f"[-] {optimizer_name} requires closure.{note}"
+        super().__init__(self.message)
+
+
+class NegativeLRError(Exception):
+    """Raised when learning rate is negative."""
+
+    def __init__(self, lr: float, lr_type: str = ""):
+        self.note: str = lr_type if lr_type else "learning rate"
+        self.message: str = f"[-] {self.note} must be positive. ({lr} > 0)"
+        super().__init__(self.message)
+
+
+class NegativeStepError(Exception):
+    """Raised when step is negative."""
+
+    def __init__(self, num_steps: int, step_type: str = ""):
+        self.note: str = step_type if step_type else "step"
+        self.message: str = f"[-] {self.note} must be positive. ({num_steps} > 0)"
+        super().__init__(self.message)
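A sketch of how the new exception types (added in fusion_bench/optim/exception.py according to the file list) might be used when validating optimizer arguments; the validation function itself is illustrative, not part of the package:

from fusion_bench.optim.exception import NegativeLRError, NegativeStepError

def validate_schedule(lr: float, warmup_steps: int) -> None:
    # Hypothetical helper: reject invalid hyperparameters early.
    if lr < 0.0:
        raise NegativeLRError(lr)
    if warmup_steps < 0:
        raise NegativeStepError(warmup_steps, step_type="warmup steps")

validate_schedule(lr=1e-4, warmup_steps=100)  # passes
# validate_schedule(lr=-1.0, warmup_steps=0) would raise:
#   NegativeLRError: [-] learning rate must be positive. (-1.0 > 0)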
fusion_bench/optim/lr_scheduler/__init__.py
ADDED
@@ -0,0 +1 @@
+from .linear_warmup import *