fusion-bench 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff compares two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +12 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/clip_finetune.py +6 -4
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/dop/__init__.py +1 -0
- fusion_bench/method/dop/dop.py +366 -0
- fusion_bench/method/dop/min_norm_solvers.py +227 -0
- fusion_bench/method/dop/utils.py +73 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +55 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +1 -0
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/METADATA +8 -2
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/RECORD +75 -56
- fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/method/depth_upscaling.yaml +9 -0
- fusion_bench_config/method/dop/dop.yaml +30 -0
- fusion_bench_config/method/dummy.yaml +6 -0
- fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
- fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
- fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
- fusion_bench_config/method/linear/weighted_average.yaml +3 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +1 -1
- fusion_bench_config/method/model_recombination.yaml +8 -0
- fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
- fusion_bench_config/method/opcm/opcm.yaml +5 -0
- fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
- fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
- fusion_bench_config/method/opcm/weight_average.yaml +5 -0
- fusion_bench_config/method/simple_average.yaml +9 -0
- fusion_bench_config/method/slerp/slerp.yaml +9 -0
- fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
- fusion_bench_config/method/task_arithmetic.yaml +9 -0
- fusion_bench_config/method/ties_merging.yaml +3 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py
CHANGED
@@ -5,46 +5,156 @@
 # ██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║
 # ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝
 # flake8: noqa: F401
-from . import (
-    constants,
-    dataset,
-    method,
-    metrics,
-    mixins,
-    modelpool,
-    models,
-    optim,
-    programs,
-    taskpool,
-    tasks,
-    utils,
-)
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+from . import constants, metrics, optim, tasks
 from .constants import RuntimeConstants
-from .method import (
-    …
-)
+from .method import _available_algorithms
+
+_extra_objects = {
+    "RuntimeConstants": RuntimeConstants,
+    "constants": constants,
+    "metrics": metrics,
+    "optim": optim,
+    "tasks": tasks,
+}
+_import_structure = {
+    "dataset": ["CLIPDataset"],
+    "method": _available_algorithms,
+    "mixins": [
+        "CLIPClassificationMixin",
+        "FabricTrainingMixin",
+        "HydraConfigMixin",
+        "LightningFabricMixin",
+        "OpenCLIPClassificationMixin",
+        "PyinstrumentProfilerMixin",
+        "SimpleProfilerMixin",
+        "YAMLSerializationMixin",
+        "auto_register_config",
+    ],
+    "modelpool": [
+        "AutoModelPool",
+        "BaseModelPool",
+        "CausalLMBackbonePool",
+        "CausalLMPool",
+        "CLIPVisionModelPool",
+        "GPT2ForSequenceClassificationPool",
+        "HuggingFaceGPT2ClassificationPool",
+        "NYUv2ModelPool",
+        "OpenCLIPVisionModelPool",
+        "PeftModelForSeq2SeqLMPool",
+        "ResNetForImageClassificationPool",
+        "Seq2SeqLMPool",
+        "SequenceClassificationModelPool",
+    ],
+    "models": [
+        "create_default_model_card",
+        "load_model_card_template",
+        "save_pretrained_with_remote_code",
+        "separate_load",
+        "separate_save",
+    ],
+    "programs": ["BaseHydraProgram", "FabricModelFusionProgram"],
+    "taskpool": [
+        "BaseTaskPool",
+        "CLIPVisionModelTaskPool",
+        "DummyTaskPool",
+        "GPT2TextClassificationTaskPool",
+        "LMEvalHarnessTaskPool",
+        "OpenCLIPVisionModelTaskPool",
+        "NYUv2TaskPool",
+    ],
+    "utils": [
+        "ArithmeticStateDict",
+        "BoolStateDictType",
+        "LazyStateDict",
+        "StateDictType",
+        "TorchModelType",
+        "cache_with_joblib",
+        "get_rankzero_logger",
+        "import_object",
+        "instantiate",
+        "parse_dtype",
+        "print_parameters",
+        "seed_everything_by_time",
+        "set_default_cache_dir",
+        "set_print_function_call",
+        "set_print_function_call_permeanent",
+        "timeit_context",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .dataset import CLIPDataset
+    from .method import BaseAlgorithm, BaseModelFusionAlgorithm
+    from .mixins import (
+        CLIPClassificationMixin,
+        FabricTrainingMixin,
+        HydraConfigMixin,
+        LightningFabricMixin,
+        OpenCLIPClassificationMixin,
+        PyinstrumentProfilerMixin,
+        SimpleProfilerMixin,
+        YAMLSerializationMixin,
+        auto_register_config,
+    )
+    from .modelpool import (
+        AutoModelPool,
+        BaseModelPool,
+        CausalLMBackbonePool,
+        CausalLMPool,
+        CLIPVisionModelPool,
+        GPT2ForSequenceClassificationPool,
+        HuggingFaceGPT2ClassificationPool,
+        NYUv2ModelPool,
+        OpenCLIPVisionModelPool,
+        PeftModelForSeq2SeqLMPool,
+        ResNetForImageClassificationPool,
+        Seq2SeqLMPool,
+        SequenceClassificationModelPool,
+    )
+    from .models import (
+        create_default_model_card,
+        load_model_card_template,
+        save_pretrained_with_remote_code,
+        separate_load,
+        separate_save,
+    )
+    from .programs import BaseHydraProgram, FabricModelFusionProgram
+    from .taskpool import (
+        BaseTaskPool,
+        CLIPVisionModelTaskPool,
+        DummyTaskPool,
+        GPT2TextClassificationTaskPool,
+        LMEvalHarnessTaskPool,
+        NYUv2TaskPool,
+        OpenCLIPVisionModelTaskPool,
+    )
+    from .utils import (
+        ArithmeticStateDict,
+        BoolStateDictType,
+        LazyStateDict,
+        StateDictType,
+        TorchModelType,
+        cache_with_joblib,
+        get_rankzero_logger,
+        import_object,
+        instantiate,
+        parse_dtype,
+        print_parameters,
+        seed_everything_by_time,
+        set_default_cache_dir,
+        set_print_function_call,
+        set_print_function_call_permeanent,
+        timeit_context,
+    )
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
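The package `__init__` now defers heavy submodule imports until first use. A minimal sketch of the general pattern (illustrative only; fusion_bench's real implementation is `LazyImporter` in `fusion_bench/utils/lazy_imports.py`, and the class below is hypothetical):

```python
import importlib
import types


class _LazyModule(types.ModuleType):
    """Load a submodule only when one of its exported names is first accessed."""

    def __init__(self, name, import_structure, extra_objects=None):
        super().__init__(name)
        # Map each exported attribute back to the submodule that defines it.
        self._attr_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }
        self._extra_objects = dict(extra_objects or {})

    def __getattr__(self, attr):
        if attr in self._extra_objects:
            return self._extra_objects[attr]
        if attr in self._attr_to_module:
            module = importlib.import_module(
                f"{self.__name__}.{self._attr_to_module[attr]}"
            )
            value = getattr(module, attr)
            setattr(self, attr, value)  # cache so __getattr__ is not hit again
            return value
        raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
```

Replacing `sys.modules[__name__]` with such an object keeps `import fusion_bench` cheap while an access like `fusion_bench.CLIPVisionModelPool` still resolves on demand; the `TYPE_CHECKING` branch preserves the eager imports for static type checkers and IDEs.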
fusion_bench/dataset/__init__.py
CHANGED
@@ -1,16 +1,20 @@
 # flake8: noqa F401
-
-from …
+import sys
+from typing import TYPE_CHECKING
 
-from …
+from omegaconf import DictConfig, open_dict
 
-from . …
+from fusion_bench.utils.lazy_imports import LazyImporter
 
 
 def load_dataset_from_config(dataset_config: DictConfig):
     """
     Load the dataset from the configuration.
     """
+    from datasets import load_dataset
+
+    from fusion_bench.utils import instantiate
+
     assert hasattr(dataset_config, "type"), "Dataset type not specified"
     if dataset_config.type == "instantiate":
         return instantiate(dataset_config.object)
@@ -27,3 +31,22 @@ def load_dataset_from_config(dataset_config: DictConfig):
         return dataset
     else:
         raise ValueError(f"Unknown dataset type: {dataset_config.type}")
+
+
+_extra_objects = {
+    "load_dataset_from_config": load_dataset_from_config,
+}
+_import_structure = {
+    "clip_dataset": ["CLIPDataset"],
+}
+
+if TYPE_CHECKING:
+    from .clip_dataset import CLIPDataset
+
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
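Moving the `datasets` and `instantiate` imports inside `load_dataset_from_config` means importing `fusion_bench.dataset` no longer pulls those dependencies in eagerly. A hedged usage sketch of the `instantiate` branch (the target dataset and its arguments are hypothetical):

```python
from omegaconf import DictConfig

from fusion_bench.dataset import load_dataset_from_config

# An "instantiate"-type config delegates to fusion_bench's Hydra-style
# instantiate(), so "object" holds a _target_ spec (values illustrative).
config = DictConfig(
    {
        "type": "instantiate",
        "object": {
            "_target_": "torchvision.datasets.CIFAR10",
            "root": "./data",
            "train": True,
            "download": True,
        },
    }
)
dataset = load_dataset_from_config(config)
```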
fusion_bench/dataset/clip_dataset.py
CHANGED
@@ -6,7 +6,7 @@ from typing import Optional, Tuple
 
 import torch
 from torch.utils.data import Dataset
-from transformers import CLIPProcessor, ProcessorMixin
+from transformers import BaseImageProcessor, CLIPProcessor, ProcessorMixin
 
 __all__ = ["CLIPDataset"]
 
@@ -60,7 +60,7 @@ class CLIPDataset(torch.utils.data.Dataset):
             raise ValueError("Each item should be a dictionary or a tuple of length 2")
         image = item["image"]
         if self.processor is not None:
-            if isinstance(self.processor, ProcessorMixin):
+            if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
                 # Apply the processor to the image to get the input tensor
                 inputs = self.processor(images=[image], return_tensors="pt")[
                     "pixel_values"
fusion_bench/method/__init__.py
CHANGED
@@ -2,6 +2,7 @@
 import sys
 from typing import TYPE_CHECKING
 
+from fusion_bench.utils import join_lists
 from fusion_bench.utils.lazy_imports import LazyImporter
 
 _import_structure = {
@@ -12,6 +13,8 @@ _import_structure = {
     "classification": [
         "ImageClassificationFineTuningForCLIP",
         "ContinualImageClassificationFineTuningForCLIP",
+        "ImageClassificationFineTuning",
+        "ImageClassificationFineTuning_Test",
     ],
     "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT", "BradleyTerryRewardModeling"],
     # analysis
@@ -67,6 +70,7 @@ _import_structure = {
         "IsotropicMergingInCommonSubspace",
     ],
     "opcm": ["OPCMForCLIP"],
+    "dop": ["ContinualDOPForCLIP"],
     "gossip": [
         "CLIPLayerWiseGossipAlgorithm",
         "CLIPTaskWiseGossipAlgorithm",
@@ -131,7 +135,10 @@ _import_structure = {
         "ProgressivePruningForMixtral",
     ],
 }
-
+_available_algorithms = join_lists(list(_import_structure.values()))
+_extra_objects = {
+    "_available_algorithms": _available_algorithms,
+}
 
 if TYPE_CHECKING:
     from .ada_svd import AdaSVDMergingForCLIPVisionModel
@@ -141,6 +148,8 @@ if TYPE_CHECKING:
     from .bitdelta import BitDeltaAlgorithm
     from .classification import (
         ContinualImageClassificationFineTuningForCLIP,
+        ImageClassificationFineTuning,
+        ImageClassificationFineTuning_Test,
         ImageClassificationFineTuningForCLIP,
     )
     from .concrete_subspace import (
@@ -204,6 +213,7 @@ if TYPE_CHECKING:
     from .model_recombination import ModelRecombinationAlgorithm
     from .model_stock import ModelStock
     from .opcm import OPCMForCLIP
+    from .dop import ContinualDOPForCLIP
     from .pruning import (
         MagnitudeDiffPruningAlgorithm,
         MagnitudePruningForLlama,
@@ -252,4 +262,5 @@ else:
         __name__,
         globals()["__file__"],
         _import_structure,
+        extra_objects=_extra_objects,
     )
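`_available_algorithms` flattens every per-submodule export list in `_import_structure` into one list of algorithm names, which the top-level `__init__` then re-exports lazily. Judging from this usage, `fusion_bench.utils.join_lists` behaves like the following minimal stand-in (a sketch, not the library's actual implementation):

```python
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def join_lists(lists: Iterable[List[T]]) -> List[T]:
    """Concatenate a sequence of lists into one flat list."""
    return [item for sublist in lists for item in sublist]


_import_structure = {
    "opcm": ["OPCMForCLIP"],
    "dop": ["ContinualDOPForCLIP"],
}
assert join_lists(list(_import_structure.values())) == [
    "OPCMForCLIP",
    "ContinualDOPForCLIP",
]
```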
fusion_bench/method/classification/__init__.py
CHANGED
@@ -1,3 +1,28 @@
 # flake8: noqa F401
-
-from …
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+_import_structure = {
+    "clip_finetune": ["ImageClassificationFineTuningForCLIP"],
+    "continual_clip_finetune": ["ContinualImageClassificationFineTuningForCLIP"],
+    "image_classification_finetune": [
+        "ImageClassificationFineTuning",
+        "ImageClassificationFineTuning_Test",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .clip_finetune import ImageClassificationFineTuningForCLIP
+    from .continual_clip_finetune import ContinualImageClassificationFineTuningForCLIP
+    from .image_classification_finetune import (
+        ImageClassificationFineTuning,
+        ImageClassificationFineTuning_Test,
+    )
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+    )
fusion_bench/method/classification/clip_finetune.py
CHANGED
@@ -5,8 +5,8 @@ Fine-tune CLIP-ViT-B/32:
 
 ```bash
 fusion_bench \
-    method=clip_finetune \
-    modelpool=clip-vit-base-patch32_mtl \
+    method=classification/clip_finetune \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_mtl \
     taskpool=dummy
 ```
 
@@ -15,12 +15,14 @@ Fine-tune CLIP-ViT-L/14 on eight GPUs with a per-device per-task batch size of 2
 ```bash
 fusion_bench \
     fabric.devices=8 \
-    method=clip_finetune \
+    method=classification/clip_finetune \
     method.batch_size=2 \
-    modelpool=clip-vit-base-patch32_mtl \
+    modelpool=CLIPVisionModelPool/clip-vit-base-patch32_mtl \
     modelpool.models.0.path=openai/clip-vit-large-patch14 \
     taskpool=dummy
 ```
+
+See `examples/clip_finetune` for more details.
 """
 
 import os
fusion_bench/method/classification/image_classification_finetune.py
ADDED
@@ -0,0 +1,214 @@
+import os
+from typing import Optional
+
+import lightning as L
+import lightning.pytorch.callbacks as pl_callbacks
+import torch
+from lightning.pytorch.loggers import TensorBoardLogger
+from lightning_utilities.core.rank_zero import rank_zero_only
+from lit_learn.lit_modules import ERM_LitModule
+from omegaconf import DictConfig
+from torch import nn
+from torch.utils.data import DataLoader
+from torchmetrics.classification import Accuracy
+
+from fusion_bench import (
+    BaseAlgorithm,
+    BaseModelPool,
+    RuntimeConstants,
+    auto_register_config,
+    get_rankzero_logger,
+    instantiate,
+)
+from fusion_bench.dataset import CLIPDataset
+from fusion_bench.modelpool import ResNetForImageClassificationPool
+from fusion_bench.tasks.clip_classification import get_num_classes
+
+log = get_rankzero_logger(__name__)
+
+
+@auto_register_config
+class ImageClassificationFineTuning(BaseAlgorithm):
+    def __init__(
+        self,
+        max_epochs: Optional[int],
+        max_steps: Optional[int],
+        label_smoothing: float,
+        optimizer: DictConfig,
+        lr_scheduler: DictConfig,
+        dataloader_kwargs: DictConfig,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        assert (max_epochs is None) or (
+            max_steps is None or max_steps < 0
+        ), "Only one of max_epochs or max_steps should be set."
+        self.training_interval = "epoch" if max_epochs is not None else "step"
+        if self.training_interval == "epoch":
+            self.max_steps = -1
+        log.info(f"Training interval: {self.training_interval}")
+        log.info(f"Max epochs: {max_epochs}, max steps: {max_steps}")
+
+    def run(self, modelpool: ResNetForImageClassificationPool):
+        # load model and dataset
+        model = modelpool.load_pretrained_or_first_model()
+        assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
+
+        assert (
+            len(modelpool.train_dataset_names) == 1
+        ), "Exactly one training dataset is required."
+        self.dataset_name = dataset_name = modelpool.train_dataset_names[0]
+        num_classes = get_num_classes(dataset_name)
+        train_dataset = modelpool.load_train_dataset(dataset_name)
+        train_dataset = CLIPDataset(
+            train_dataset, processor=modelpool.load_processor(stage="train")
+        )
+        train_loader = self.get_dataloader(train_dataset, stage="train")
+        if modelpool.has_val_dataset:
+            val_dataset = modelpool.load_val_dataset(dataset_name)
+            val_dataset = CLIPDataset(
+                val_dataset, processor=modelpool.load_processor(stage="val")
+            )
+            val_loader = self.get_dataloader(val_dataset, stage="val")
+
+        # configure optimizer
+        optimizer = instantiate(self.optimizer, params=model.parameters())
+        if self.lr_scheduler is not None:
+            lr_scheduler = instantiate(self.lr_scheduler, optimizer=optimizer)
+            optimizer = {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": lr_scheduler,
+                    "interval": self.training_interval,
+                    "frequency": 1,
+                },
+            }
+        log.info(f"optimizer:\n{optimizer}")
+
+        lit_module = ERM_LitModule(
+            model,
+            optimizer,
+            objective=nn.CrossEntropyLoss(label_smoothing=self.label_smoothing),
+            metrics={
+                "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                "acc@5": Accuracy(task="multiclass", num_classes=num_classes, top_k=5),
+            },
+        )
+
+        log_dir = (
+            self._program.path.log_dir
+            if self._program is not None
+            else "outputs/lightning_logs"
+        )
+        trainer = L.Trainer(
+            max_epochs=self.max_epochs,
+            max_steps=self.max_steps,
+            accelerator="auto",
+            devices="auto",
+            callbacks=[
+                pl_callbacks.LearningRateMonitor(logging_interval="step"),
+                pl_callbacks.DeviceStatsMonitor(),
+            ],
+            logger=TensorBoardLogger(
+                save_dir=log_dir,
+                name="",
+            ),
+            fast_dev_run=RuntimeConstants.debug,
+        )
+
+        trainer.fit(
+            lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader
+        )
+        model = lit_module.model
+        if rank_zero_only.rank == 0:
+            log.info(f"Saving the final model to {log_dir}/raw_checkpoints/final")
+            modelpool.save_model(
+                model,
+                path=os.path.join(
+                    trainer.log_dir if trainer.log_dir is not None else log_dir,
+                    "raw_checkpoints",
+                    "final",
+                ),
+            )
+        return model
+
+    def get_dataloader(self, dataset, stage: str):
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
+
+
+@auto_register_config
+class ImageClassificationFineTuning_Test(BaseAlgorithm):
+    def __init__(self, checkpoint_path: str, dataloader_kwargs: DictConfig, **kwargs):
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: BaseModelPool):
+        assert (
+            modelpool.has_val_dataset or modelpool.has_test_dataset
+        ), "No validation or test dataset found in the model pool."
+
+        # load model and dataset
+        model = modelpool.load_pretrained_or_first_model()
+        assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
+
+        if modelpool.has_test_dataset:
+            assert (
+                len(modelpool.test_dataset_names) == 1
+            ), "Exactly one test dataset is required."
+            self.dataset_name = dataset_name = modelpool.test_dataset_names[0]
+            dataset = modelpool.load_test_dataset(dataset_name)
+            dataset = CLIPDataset(
+                dataset, processor=modelpool.load_processor(stage="test")
+            )
+        else:
+            assert (
+                len(modelpool.val_dataset_names) == 1
+            ), "Exactly one validation dataset is required."
+            self.dataset_name = dataset_name = modelpool.val_dataset_names[0]
+            dataset = modelpool.load_val_dataset(dataset_name)
+            dataset = CLIPDataset(
+                dataset, processor=modelpool.load_processor(stage="test")
+            )
+        num_classes = get_num_classes(dataset_name)
+
+        test_loader = self.get_dataloader(dataset, stage="test")
+
+        if self.checkpoint_path is None:
+            lit_module = ERM_LitModule(
+                model,
+                metrics={
+                    "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                    "acc@5": Accuracy(
+                        task="multiclass", num_classes=num_classes, top_k=5
+                    ),
+                },
+            )
+        else:
+            lit_module = ERM_LitModule.load_from_checkpoint(
+                checkpoint_path=self.checkpoint_path,
+                model=model,
+                metrics={
+                    "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                    "acc@5": Accuracy(
+                        task="multiclass", num_classes=num_classes, top_k=5
+                    ),
+                },
+            )
+
+        trainer = L.Trainer(
+            devices=1, num_nodes=1, logger=False, fast_dev_run=RuntimeConstants.debug
+        )
+
+        test_metrics = trainer.test(lit_module, dataloaders=test_loader)
+        log.info(f"Test metrics: {test_metrics}")
+        return model
+
+    def get_dataloader(self, dataset, stage: str):
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
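A hedged sketch of driving the new `ImageClassificationFineTuning` directly from Python rather than through the Hydra config (`image_classification_finetune.yaml`); every value below is illustrative, and in practice the algorithm is normally constructed by `instantiate` from that config:

```python
from omegaconf import DictConfig

from fusion_bench.method.classification import ImageClassificationFineTuning

algorithm = ImageClassificationFineTuning(
    max_epochs=10,
    max_steps=None,  # only one of max_epochs/max_steps may be set
    label_smoothing=0.1,
    optimizer=DictConfig({"_target_": "torch.optim.AdamW", "lr": 1e-4}),
    lr_scheduler=DictConfig(
        {"_target_": "torch.optim.lr_scheduler.CosineAnnealingLR", "T_max": 10}
    ),
    dataloader_kwargs=DictConfig({"batch_size": 128, "num_workers": 4}),
)
# modelpool would be a ResNetForImageClassificationPool with exactly one
# training dataset, e.g. built from resnet18_cifar10.yaml:
# model = algorithm.run(modelpool)
```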
fusion_bench/method/dop/__init__.py
ADDED
@@ -0,0 +1 @@
+from .dop import ContinualDOPForCLIP