fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +18 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/serialization.py +25 -78
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +2 -1
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/devices.py +30 -8
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +58 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +8 -3
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py
CHANGED

@@ -5,46 +5,156 @@
 # ██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║
 # ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝
 # flake8: noqa: F401
-from . import (
-    …
-    modelpool,
-    models,
-    optim,
-    programs,
-    taskpool,
-    tasks,
-    utils,
-)
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+from . import constants, metrics, optim, tasks
 from .constants import RuntimeConstants
-from .method import (
-    …
-)
+from .method import _available_algorithms
+
+_extra_objects = {
+    "RuntimeConstants": RuntimeConstants,
+    "constants": constants,
+    "metrics": metrics,
+    "optim": optim,
+    "tasks": tasks,
+}
+_import_structure = {
+    "dataset": ["CLIPDataset"],
+    "method": _available_algorithms,
+    "mixins": [
+        "CLIPClassificationMixin",
+        "FabricTrainingMixin",
+        "HydraConfigMixin",
+        "LightningFabricMixin",
+        "OpenCLIPClassificationMixin",
+        "PyinstrumentProfilerMixin",
+        "SimpleProfilerMixin",
+        "YAMLSerializationMixin",
+        "auto_register_config",
+    ],
+    "modelpool": [
+        "AutoModelPool",
+        "BaseModelPool",
+        "CausalLMBackbonePool",
+        "CausalLMPool",
+        "CLIPVisionModelPool",
+        "GPT2ForSequenceClassificationPool",
+        "HuggingFaceGPT2ClassificationPool",
+        "NYUv2ModelPool",
+        "OpenCLIPVisionModelPool",
+        "PeftModelForSeq2SeqLMPool",
+        "ResNetForImageClassificationPool",
+        "Seq2SeqLMPool",
+        "SequenceClassificationModelPool",
+    ],
+    "models": [
+        "create_default_model_card",
+        "load_model_card_template",
+        "save_pretrained_with_remote_code",
+        "separate_load",
+        "separate_save",
+    ],
+    "programs": ["BaseHydraProgram", "FabricModelFusionProgram"],
+    "taskpool": [
+        "BaseTaskPool",
+        "CLIPVisionModelTaskPool",
+        "DummyTaskPool",
+        "GPT2TextClassificationTaskPool",
+        "LMEvalHarnessTaskPool",
+        "OpenCLIPVisionModelTaskPool",
+        "NYUv2TaskPool",
+    ],
+    "utils": [
+        "ArithmeticStateDict",
+        "BoolStateDictType",
+        "LazyStateDict",
+        "StateDictType",
+        "TorchModelType",
+        "cache_with_joblib",
+        "get_rankzero_logger",
+        "import_object",
+        "instantiate",
+        "parse_dtype",
+        "print_parameters",
+        "seed_everything_by_time",
+        "set_default_cache_dir",
+        "set_print_function_call",
+        "set_print_function_call_permeanent",
+        "timeit_context",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .dataset import CLIPDataset
+    from .method import BaseAlgorithm, BaseModelFusionAlgorithm
+    from .mixins import (
+        CLIPClassificationMixin,
+        FabricTrainingMixin,
+        HydraConfigMixin,
+        LightningFabricMixin,
+        OpenCLIPClassificationMixin,
+        PyinstrumentProfilerMixin,
+        SimpleProfilerMixin,
+        YAMLSerializationMixin,
+        auto_register_config,
+    )
+    from .modelpool import (
+        AutoModelPool,
+        BaseModelPool,
+        CausalLMBackbonePool,
+        CausalLMPool,
+        CLIPVisionModelPool,
+        GPT2ForSequenceClassificationPool,
+        HuggingFaceGPT2ClassificationPool,
+        NYUv2ModelPool,
+        OpenCLIPVisionModelPool,
+        PeftModelForSeq2SeqLMPool,
+        ResNetForImageClassificationPool,
+        Seq2SeqLMPool,
+        SequenceClassificationModelPool,
+    )
+    from .models import (
+        create_default_model_card,
+        load_model_card_template,
+        save_pretrained_with_remote_code,
+        separate_load,
+        separate_save,
+    )
+    from .programs import BaseHydraProgram, FabricModelFusionProgram
+    from .taskpool import (
+        BaseTaskPool,
+        CLIPVisionModelTaskPool,
+        DummyTaskPool,
+        GPT2TextClassificationTaskPool,
+        LMEvalHarnessTaskPool,
+        NYUv2TaskPool,
+        OpenCLIPVisionModelTaskPool,
+    )
+    from .utils import (
+        ArithmeticStateDict,
+        BoolStateDictType,
+        LazyStateDict,
+        StateDictType,
+        TorchModelType,
+        cache_with_joblib,
+        get_rankzero_logger,
+        import_object,
+        instantiate,
+        parse_dtype,
+        print_parameters,
+        seed_everything_by_time,
+        set_default_cache_dir,
+        set_print_function_call,
+        set_print_function_call_permeanent,
+        timeit_context,
+    )
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
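The package root now installs a `LazyImporter` into `sys.modules`, so `import fusion_bench` no longer eagerly imports every algorithm, model pool, and task pool; names listed in `_import_structure` are resolved on first attribute access, while the `TYPE_CHECKING` branch keeps the full API visible to static analyzers. A minimal sketch of how such a lazy module shim can behave (illustrative only; the real implementation lives in `fusion_bench/utils/lazy_imports.py`, and `_LazyModuleSketch` is a hypothetical stand-in):

```python
import importlib
import types


class _LazyModuleSketch(types.ModuleType):
    """Hypothetical stand-in for fusion_bench's LazyImporter."""

    def __init__(self, name, import_structure, extra_objects=None):
        super().__init__(name)
        # map each exported name to the submodule that defines it
        self._name_to_module = {
            attr: submodule
            for submodule, attrs in import_structure.items()
            for attr in attrs
        }
        self._extra_objects = dict(extra_objects or {})

    def __getattr__(self, attr):
        # called only when normal attribute lookup fails
        if attr in self._extra_objects:
            return self._extra_objects[attr]
        if attr in self._name_to_module:
            submodule = importlib.import_module(
                f"{self.__name__}.{self._name_to_module[attr]}"
            )
            value = getattr(submodule, attr)
            setattr(self, attr, value)  # cache so later lookups skip __getattr__
            return value
        raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
```

Under this scheme, accessing something like `fusion_bench.SimpleAverageAlgorithm` triggers the import of `fusion_bench.method` only at that point, which keeps CLI startup cheap.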
fusion_bench/dataset/__init__.py
CHANGED

@@ -1,16 +1,20 @@
 # flake8: noqa F401
-…
-from …
+import sys
+from typing import TYPE_CHECKING
 
-from …
+from omegaconf import DictConfig, open_dict
 
-from .…
+from fusion_bench.utils.lazy_imports import LazyImporter
 
 
 def load_dataset_from_config(dataset_config: DictConfig):
     """
     Load the dataset from the configuration.
     """
+    from datasets import load_dataset
+
+    from fusion_bench.utils import instantiate
+
     assert hasattr(dataset_config, "type"), "Dataset type not specified"
     if dataset_config.type == "instantiate":
         return instantiate(dataset_config.object)
@@ -27,3 +31,22 @@ def load_dataset_from_config(dataset_config: DictConfig):
         return dataset
     else:
         raise ValueError(f"Unknown dataset type: {dataset_config.type}")
+
+
+_extra_objects = {
+    "load_dataset_from_config": load_dataset_from_config,
+}
+_import_structure = {
+    "clip_dataset": ["CLIPDataset"],
+}
+
+if TYPE_CHECKING:
+    from .clip_dataset import CLIPDataset
+
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
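`load_dataset_from_config` keeps its dispatch-on-`type` contract but now defers the heavy `datasets` and `instantiate` imports to call time. A sketch of the `"instantiate"` branch in use, assuming `fusion_bench.utils.instantiate` is Hydra-style (the `_target_` and its arguments below are illustrative, not a config shipped with the package):

```python
from omegaconf import DictConfig, OmegaConf

from fusion_bench.dataset import load_dataset_from_config

# illustrative config; any _target_ reachable at runtime works here
cfg: DictConfig = OmegaConf.create(
    {
        "type": "instantiate",
        "object": {
            "_target_": "torchvision.datasets.CIFAR10",
            "root": "./data",
            "train": True,
            "download": True,
        },
    }
)
dataset = load_dataset_from_config(cfg)  # returns the instantiated dataset
```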
fusion_bench/dataset/clip_dataset.py
CHANGED

@@ -6,7 +6,7 @@ from typing import Optional, Tuple
 
 import torch
 from torch.utils.data import Dataset
-from transformers import CLIPProcessor, ProcessorMixin
+from transformers import BaseImageProcessor, CLIPProcessor, ProcessorMixin
 
 __all__ = ["CLIPDataset"]
 
@@ -60,7 +60,7 @@ class CLIPDataset(torch.utils.data.Dataset):
             raise ValueError("Each item should be a dictionary or a tuple of length 2")
         image = item["image"]
         if self.processor is not None:
-            if isinstance(self.processor, ProcessorMixin):
+            if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
                 # Apply the processor to the image to get the input tensor
                 inputs = self.processor(images=[image], return_tensors="pt")[
                     "pixel_values"
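The widened `isinstance` check lets `CLIPDataset` accept a bare image processor (any `BaseImageProcessor` subclass, such as `CLIPImageProcessor`) in addition to a full `CLIPProcessor`. A hedged usage sketch; the `(pixel_values, label)` output format is assumed from the surrounding code:

```python
from PIL import Image
from transformers import CLIPImageProcessor

from fusion_bench.dataset.clip_dataset import CLIPDataset

# CLIPImageProcessor subclasses BaseImageProcessor, so after this change it is
# routed through processor(images=[image], return_tensors="pt")["pixel_values"]
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

# any sequence of {"image": ..., "label": ...} dicts (or (image, label) tuples) works
items = [{"image": Image.new("RGB", (224, 224)), "label": 0}]
dataset = CLIPDataset(items, processor=processor)
pixel_values, label = dataset[0]  # assumed output: (image tensor, label)
```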
fusion_bench/method/__init__.py
CHANGED

@@ -2,6 +2,7 @@
 import sys
 from typing import TYPE_CHECKING
 
+from fusion_bench.utils import join_lists
 from fusion_bench.utils.lazy_imports import LazyImporter
 
 _import_structure = {
@@ -12,6 +13,8 @@ _import_structure = {
     "classification": [
         "ImageClassificationFineTuningForCLIP",
         "ContinualImageClassificationFineTuningForCLIP",
+        "ImageClassificationFineTuning",
+        "ImageClassificationFineTuning_Test",
     ],
     "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT", "BradleyTerryRewardModeling"],
     # analysis
@@ -26,9 +29,12 @@ _import_structure = {
     "linear": [
         "ExPOAlgorithm",
         "ExPOAlgorithmForLlama",
+        "SimpleAverageForCausalLM",
         "SimpleAverageForLlama",
+        "TaskArithmeticForCausalLM",
         "TaskArithmeticForLlama",
         "LinearInterpolationAlgorithm",
+        "TiesMergingForCausalLM",
     ],
     "slerp": ["SlerpMergeAlgorithm", "SlerpForCausalLM"],
     "simple_average": ["SimpleAverageAlgorithm"],
@@ -72,6 +78,7 @@ _import_structure = {
     "fw_merging": ["FrankWolfeHardAlgorithm", "FrankWolfeSoftAlgorithm"],
     "tall_mask": ["TallMaskTaskArithmeticAlgorithm"],
     "model_stock": ["ModelStock"],
+    "wudi": ["wudi_merging", "WUDIMerging"],
     # plug-and-play model merging methods
     "concrete_subspace": [
         "ConcreteTaskArithmeticAlgorithmForCLIP",
@@ -127,7 +134,10 @@ _import_structure = {
         "ProgressivePruningForMixtral",
     ],
 }
-
+_available_algorithms = join_lists(list(_import_structure.values()))
+_extra_objects = {
+    "_available_algorithms": _available_algorithms,
+}
 
 if TYPE_CHECKING:
     from .ada_svd import AdaSVDMergingForCLIPVisionModel
@@ -137,6 +147,8 @@ if TYPE_CHECKING:
     from .bitdelta import BitDeltaAlgorithm
     from .classification import (
         ContinualImageClassificationFineTuningForCLIP,
+        ImageClassificationFineTuning,
+        ImageClassificationFineTuning_Test,
         ImageClassificationFineTuningForCLIP,
     )
     from .concrete_subspace import (
@@ -184,8 +196,11 @@ if TYPE_CHECKING:
         ExPOAlgorithm,
         ExPOAlgorithmForLlama,
         LinearInterpolationAlgorithm,
+        SimpleAverageForCausalLM,
         SimpleAverageForLlama,
+        TaskArithmeticForCausalLM,
         TaskArithmeticForLlama,
+        TiesMergingForCausalLM,
     )
     from .lm_finetune import *
     from .mixture_of_experts import (
@@ -238,10 +253,12 @@ if TYPE_CHECKING:
         FlanT5WeightEnsemblingMoEAlgorithm,
     )
     from .weighted_average import WeightedAverageAlgorithm, WeightedAverageForLLama
+    from .wudi import WUDIMerging, wudi_merging
 
 else:
     sys.modules[__name__] = LazyImporter(
         __name__,
         globals()["__file__"],
         _import_structure,
+        extra_objects=_extra_objects,
     )
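`_available_algorithms` flattens every name exported by `_import_structure` into a single list, which the package root reuses for its own `"method"` entry so the two stay in sync automatically. Assuming `join_lists` is a plain one-level list-of-lists flattener, the computation is equivalent to:

```python
def join_lists(lists):
    """Assumed behavior of fusion_bench.utils.join_lists: flatten one level."""
    return [item for sublist in lists for item in sublist]


_import_structure = {
    "linear": ["SimpleAverageForCausalLM", "TaskArithmeticForCausalLM"],
    "wudi": ["wudi_merging", "WUDIMerging"],
}
_available_algorithms = join_lists(list(_import_structure.values()))
print(_available_algorithms)
# ['SimpleAverageForCausalLM', 'TaskArithmeticForCausalLM', 'wudi_merging', 'WUDIMerging']
```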
fusion_bench/method/classification/__init__.py
CHANGED

@@ -1,3 +1,28 @@
 # flake8: noqa F401
-…
-from …
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+_import_structure = {
+    "clip_finetune": ["ImageClassificationFineTuningForCLIP"],
+    "continual_clip_finetune": ["ContinualImageClassificationFineTuningForCLIP"],
+    "image_classification_finetune": [
+        "ImageClassificationFineTuning",
+        "ImageClassificationFineTuning_Test",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .clip_finetune import ImageClassificationFineTuningForCLIP
+    from .continual_clip_finetune import ContinualImageClassificationFineTuningForCLIP
+    from .image_classification_finetune import (
+        ImageClassificationFineTuning,
+        ImageClassificationFineTuning_Test,
+    )
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+    )
fusion_bench/method/classification/image_classification_finetune.py
ADDED

@@ -0,0 +1,214 @@
+import os
+from typing import Optional
+
+import lightning as L
+import lightning.pytorch.callbacks as pl_callbacks
+import torch
+from lightning.pytorch.loggers import TensorBoardLogger
+from lightning_utilities.core.rank_zero import rank_zero_only
+from lit_learn.lit_modules import ERM_LitModule
+from omegaconf import DictConfig
+from torch import nn
+from torch.utils.data import DataLoader
+from torchmetrics.classification import Accuracy
+
+from fusion_bench import (
+    BaseAlgorithm,
+    BaseModelPool,
+    RuntimeConstants,
+    auto_register_config,
+    get_rankzero_logger,
+    instantiate,
+)
+from fusion_bench.dataset import CLIPDataset
+from fusion_bench.modelpool import ResNetForImageClassificationPool
+from fusion_bench.tasks.clip_classification import get_num_classes
+
+log = get_rankzero_logger(__name__)
+
+
+@auto_register_config
+class ImageClassificationFineTuning(BaseAlgorithm):
+    def __init__(
+        self,
+        max_epochs: Optional[int],
+        max_steps: Optional[int],
+        label_smoothing: float,
+        optimizer: DictConfig,
+        lr_scheduler: DictConfig,
+        dataloader_kwargs: DictConfig,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        assert (max_epochs is None) or (
+            max_steps is None or max_steps < 0
+        ), "Only one of max_epochs or max_steps should be set."
+        self.training_interval = "epoch" if max_epochs is not None else "step"
+        if self.training_interval == "epoch":
+            self.max_steps = -1
+        log.info(f"Training interval: {self.training_interval}")
+        log.info(f"Max epochs: {max_epochs}, max steps: {max_steps}")
+
+    def run(self, modelpool: ResNetForImageClassificationPool):
+        # load model and dataset
+        model = modelpool.load_pretrained_or_first_model()
+        assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
+
+        assert (
+            len(modelpool.train_dataset_names) == 1
+        ), "Exactly one training dataset is required."
+        self.dataset_name = dataset_name = modelpool.train_dataset_names[0]
+        num_classes = get_num_classes(dataset_name)
+        train_dataset = modelpool.load_train_dataset(dataset_name)
+        train_dataset = CLIPDataset(
+            train_dataset, processor=modelpool.load_processor(stage="train")
+        )
+        train_loader = self.get_dataloader(train_dataset, stage="train")
+        if modelpool.has_val_dataset:
+            val_dataset = modelpool.load_val_dataset(dataset_name)
+            val_dataset = CLIPDataset(
+                val_dataset, processor=modelpool.load_processor(stage="val")
+            )
+            val_loader = self.get_dataloader(val_dataset, stage="val")
+
+        # configure optimizer
+        optimizer = instantiate(self.optimizer, params=model.parameters())
+        if self.lr_scheduler is not None:
+            lr_scheduler = instantiate(self.lr_scheduler, optimizer=optimizer)
+            optimizer = {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": lr_scheduler,
+                    "interval": self.training_interval,
+                    "frequency": 1,
+                },
+            }
+        log.info(f"optimizer:\n{optimizer}")
+
+        lit_module = ERM_LitModule(
+            model,
+            optimizer,
+            objective=nn.CrossEntropyLoss(label_smoothing=self.label_smoothing),
+            metrics={
+                "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                "acc@5": Accuracy(task="multiclass", num_classes=num_classes, top_k=5),
+            },
+        )
+
+        log_dir = (
+            self._program.path.log_dir
+            if self._program is not None
+            else "outputs/lightning_logs"
+        )
+        trainer = L.Trainer(
+            max_epochs=self.max_epochs,
+            max_steps=self.max_steps,
+            accelerator="auto",
+            devices="auto",
+            callbacks=[
+                pl_callbacks.LearningRateMonitor(logging_interval="step"),
+                pl_callbacks.DeviceStatsMonitor(),
+            ],
+            logger=TensorBoardLogger(
+                save_dir=log_dir,
+                name="",
+            ),
+            fast_dev_run=RuntimeConstants.debug,
+        )
+
+        trainer.fit(
+            lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader
+        )
+        model = lit_module.model
+        if rank_zero_only.rank == 0:
+            log.info(f"Saving the final model to {log_dir}/raw_checkpoints/final")
+            modelpool.save_model(
+                model,
+                path=os.path.join(
+                    trainer.log_dir if trainer.log_dir is not None else log_dir,
+                    "raw_checkpoints",
+                    "final",
+                ),
+            )
+        return model
+
+    def get_dataloader(self, dataset, stage: str):
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
+
+
+@auto_register_config
+class ImageClassificationFineTuning_Test(BaseAlgorithm):
+    def __init__(self, checkpoint_path: str, dataloader_kwargs: DictConfig, **kwargs):
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: BaseModelPool):
+        assert (
+            modelpool.has_val_dataset or modelpool.has_test_dataset
+        ), "No validation or test dataset found in the model pool."
+
+        # load model and dataset
+        model = modelpool.load_pretrained_or_first_model()
+        assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
+
+        if modelpool.has_test_dataset:
+            assert (
+                len(modelpool.test_dataset_names) == 1
+            ), "Exactly one test dataset is required."
+            self.dataset_name = dataset_name = modelpool.test_dataset_names[0]
+            dataset = modelpool.load_test_dataset(dataset_name)
+            dataset = CLIPDataset(
+                dataset, processor=modelpool.load_processor(stage="test")
+            )
+        else:
+            assert (
+                len(modelpool.val_dataset_names) == 1
+            ), "Exactly one validation dataset is required."
+            self.dataset_name = dataset_name = modelpool.val_dataset_names[0]
+            dataset = modelpool.load_val_dataset(dataset_name)
+            dataset = CLIPDataset(
+                dataset, processor=modelpool.load_processor(stage="test")
+            )
+        num_classes = get_num_classes(dataset_name)
+
+        test_loader = self.get_dataloader(dataset, stage="test")
+
+        if self.checkpoint_path is None:
+            lit_module = ERM_LitModule(
+                model,
+                metrics={
+                    "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                    "acc@5": Accuracy(
+                        task="multiclass", num_classes=num_classes, top_k=5
+                    ),
+                },
+            )
+        else:
+            lit_module = ERM_LitModule.load_from_checkpoint(
+                checkpoint_path=self.checkpoint_path,
+                model=model,
+                metrics={
+                    "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                    "acc@5": Accuracy(
+                        task="multiclass", num_classes=num_classes, top_k=5
+                    ),
+                },
+            )
+
+        trainer = L.Trainer(
+            devices=1, num_nodes=1, logger=False, fast_dev_run=RuntimeConstants.debug
+        )
+
+        test_metrics = trainer.test(lit_module, dataloaders=test_loader)
+        log.info(f"Test metrics: {test_metrics}")
+        return model
+
+    def get_dataloader(self, dataset, stage: str):
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
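`ImageClassificationFineTuning` is essentially a thin Lightning loop: it takes one model and one training dataset from the pool, builds the optimizer and LR scheduler from their `DictConfig`s via `instantiate`, wraps everything in an `ERM_LitModule` with top-1/top-5 accuracy metrics, and saves the result under `raw_checkpoints/final`. A hedged sketch of constructing it directly from Python (hyperparameter values are illustrative; the shipped defaults live in `fusion_bench_config/method/classification/image_classification_finetune.yaml`):

```python
from omegaconf import OmegaConf

from fusion_bench.method.classification import ImageClassificationFineTuning

# illustrative hyperparameters; the packaged YAML defines the real defaults
algorithm = ImageClassificationFineTuning(
    max_epochs=10,
    max_steps=None,  # the constructor allows only one of max_epochs/max_steps
    label_smoothing=0.1,
    optimizer=OmegaConf.create(
        {"_target_": "torch.optim.AdamW", "lr": 1e-3, "weight_decay": 1e-4}
    ),
    lr_scheduler=OmegaConf.create(
        {"_target_": "torch.optim.lr_scheduler.CosineAnnealingLR", "T_max": 10}
    ),
    dataloader_kwargs=OmegaConf.create({"batch_size": 128, "num_workers": 8}),
)
# model = algorithm.run(modelpool)  # modelpool: a ResNetForImageClassificationPool
```

Note that `get_dataloader` defaults `shuffle` to `True` only for the training stage, so configs need not set it explicitly.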
fusion_bench/method/ensemble.py
CHANGED

@@ -17,7 +17,21 @@ from fusion_bench.models.wrappers.ensemble import (
 log = logging.getLogger(__name__)
 
 
+@auto_register_config
 class SimpleEnsembleAlgorithm(BaseAlgorithm):
+    def __init__(
+        self,
+        device_map: Optional[Mapping[int, Union[str, torch.device]]] = None,
+        **kwargs,
+    ):
+        """
+        Initializes the SimpleEnsembleAlgorithm with an optional device map.
+
+        Args:
+            device_map (Optional[Mapping[int, Union[str, torch.device]]], optional): A mapping from model index to device. Defaults to None.
+        """
+        super().__init__(**kwargs)
+
     @torch.no_grad()
     def run(self, modelpool: BaseModelPool | List[nn.Module]) -> EnsembleModule:
         """
@@ -30,9 +44,10 @@ class SimpleEnsembleAlgorithm(BaseAlgorithm):
             EnsembleModule: The ensembled model.
         """
         log.info(f"Running ensemble algorithm with {len(modelpool)} models")
-
         models = [modelpool.load_model(m) for m in modelpool.model_names]
-…
+
+        log.info("creating ensemble module")
+        ensemble = EnsembleModule(models=models, device_map=self.device_map)
         return ensemble
 
 
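With the new `device_map`, each ensemble member can be pinned to its own device before inference; the heavily expanded `fusion_bench/models/wrappers/ensemble.py` (+136 lines in this release) is where that placement appears to be handled. A hedged usage sketch (the two-GPU placement and the aggregation comment are assumptions, not confirmed API behavior):

```python
from fusion_bench.method.ensemble import SimpleEnsembleAlgorithm

# hypothetical placement: member 0 on GPU 0, member 1 on GPU 1
algorithm = SimpleEnsembleAlgorithm(device_map={0: "cuda:0", 1: "cuda:1"})
# ensemble = algorithm.run(modelpool)  # modelpool: a BaseModelPool with the members
# outputs = ensemble(batch)            # assumed: the members' outputs are aggregated
```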
fusion_bench/method/linear/__init__.py
CHANGED

@@ -2,5 +2,9 @@
 from .expo import ExPOAlgorithm
 from .linear_interpolation import LinearInterpolationAlgorithm
 from .llama_expo import ExPOAlgorithmForLlama
-from .…
-from .…
+from .simple_average_for_causallm import SimpleAverageForCausalLM, SimpleAverageForLlama
+from .task_arithmetic_for_causallm import (
+    TaskArithmeticForCausalLM,
+    TaskArithmeticForLlama,
+)
+from .ties_merging_for_causallm import TiesMergingForCausalLM
fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py}
RENAMED

@@ -18,16 +18,16 @@ log = get_rankzero_logger(__name__)
 
 
 @auto_register_config
-class SimpleAverageForLlama(BaseAlgorithm):
+class SimpleAverageForCausalLM(BaseAlgorithm):
     R"""
     A simple averaging algorithm for LLama models. If `merge_backbone` is set to `True`, the backbone of the model will be averaged and the rest of the model will be loaded from the pre-trained model.
 
     Examples:
-        The following example demonstrates how to use the `SimpleAverageForLlama` algorithm to merge Mistral models.
+        The following example demonstrates how to use the `SimpleAverageForCausalLM` algorithm to merge Mistral models.
 
     ```bash
     fusion_bench \
-        method=linear/simple_average_for_llama \
+        method=linear/simple_average_for_causallm \
         method.model_save_path=outputs/simle_mixtral_exp_v4/simple_average \
         modelpool=CausalLMPool/simle_mixtral_exp_v4.yaml
     ```
@@ -35,7 +35,7 @@ class SimpleAverageForLlama(BaseAlgorithm):
 
     def __init__(
         self,
-        merge_backbone: bool,
+        merge_backbone: bool = False,
         model_save_path: Optional[str] = None,
         show_pbar: bool = False,
         **kwargs,
@@ -81,3 +81,7 @@ class SimpleAverageForLlama(BaseAlgorithm):
         with open(os.path.join(self.model_save_path, "README.md"), "w") as f:
             f.write(model_card_str)
         return model
+
+
+SimpleAverageForLlama = SimpleAverageForCausalLM
+"""Alias for SimpleAverageForCausalLM"""
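The trailing assignment keeps the old name importable, so configs and scripts written against 0.2.23 continue to resolve; judging from the exports in `fusion_bench/method/__init__.py`, `TaskArithmeticForLlama` appears to be kept the same way in the renamed task-arithmetic module. The alias is the very same class object, not a wrapper:

```python
from fusion_bench.method.linear.simple_average_for_causallm import (
    SimpleAverageForCausalLM,
    SimpleAverageForLlama,
)

# identity, not subclassing: both names refer to one class
assert SimpleAverageForLlama is SimpleAverageForCausalLM
```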