fusion-bench 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +10 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +55 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +1 -0
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +1 -1
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +48 -34
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py
CHANGED

@@ -5,46 +5,156 @@
 # ██║ ╚██████╔╝███████║██║╚██████╔╝██║ ╚████║ ██████╔╝███████╗██║ ╚████║╚██████╗██║ ██║
 # ╚═╝ ╚═════╝ ╚══════╝╚═╝ ╚═════╝ ╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═══╝ ╚═════╝╚═╝ ╚═╝
 # flake8: noqa: F401
-from . import (
-    constants,
-    dataset,
-    method,
-    metrics,
-    mixins,
-    modelpool,
-    models,
-    optim,
-    programs,
-    taskpool,
-    tasks,
-    utils,
-)
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+from . import constants, metrics, optim, tasks
 from .constants import RuntimeConstants
-from .method import
+from .method import _available_algorithms
+
+_extra_objects = {
+    "RuntimeConstants": RuntimeConstants,
+    "constants": constants,
+    "metrics": metrics,
+    "optim": optim,
+    "tasks": tasks,
+}
+_import_structure = {
+    "dataset": ["CLIPDataset"],
+    "method": _available_algorithms,
+    "mixins": [
+        "CLIPClassificationMixin",
+        "FabricTrainingMixin",
+        "HydraConfigMixin",
+        "LightningFabricMixin",
+        "OpenCLIPClassificationMixin",
+        "PyinstrumentProfilerMixin",
+        "SimpleProfilerMixin",
+        "YAMLSerializationMixin",
+        "auto_register_config",
+    ],
+    "modelpool": [
+        "AutoModelPool",
+        "BaseModelPool",
+        "CausalLMBackbonePool",
+        "CausalLMPool",
+        "CLIPVisionModelPool",
+        "GPT2ForSequenceClassificationPool",
+        "HuggingFaceGPT2ClassificationPool",
+        "NYUv2ModelPool",
+        "OpenCLIPVisionModelPool",
+        "PeftModelForSeq2SeqLMPool",
+        "ResNetForImageClassificationPool",
+        "Seq2SeqLMPool",
+        "SequenceClassificationModelPool",
+    ],
+    "models": [
+        "create_default_model_card",
+        "load_model_card_template",
+        "save_pretrained_with_remote_code",
+        "separate_load",
+        "separate_save",
+    ],
+    "programs": ["BaseHydraProgram", "FabricModelFusionProgram"],
+    "taskpool": [
+        "BaseTaskPool",
+        "CLIPVisionModelTaskPool",
+        "DummyTaskPool",
+        "GPT2TextClassificationTaskPool",
+        "LMEvalHarnessTaskPool",
+        "OpenCLIPVisionModelTaskPool",
+        "NYUv2TaskPool",
+    ],
+    "utils": [
+        "ArithmeticStateDict",
+        "BoolStateDictType",
+        "LazyStateDict",
+        "StateDictType",
+        "TorchModelType",
+        "cache_with_joblib",
+        "get_rankzero_logger",
+        "import_object",
+        "instantiate",
+        "parse_dtype",
+        "print_parameters",
+        "seed_everything_by_time",
+        "set_default_cache_dir",
+        "set_print_function_call",
+        "set_print_function_call_permeanent",
+        "timeit_context",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .dataset import CLIPDataset
+    from .method import BaseAlgorithm, BaseModelFusionAlgorithm
+    from .mixins import (
+        CLIPClassificationMixin,
+        FabricTrainingMixin,
+        HydraConfigMixin,
+        LightningFabricMixin,
+        OpenCLIPClassificationMixin,
+        PyinstrumentProfilerMixin,
+        SimpleProfilerMixin,
+        YAMLSerializationMixin,
+        auto_register_config,
+    )
+    from .modelpool import (
+        AutoModelPool,
+        BaseModelPool,
+        CausalLMBackbonePool,
+        CausalLMPool,
+        CLIPVisionModelPool,
+        GPT2ForSequenceClassificationPool,
+        HuggingFaceGPT2ClassificationPool,
+        NYUv2ModelPool,
+        OpenCLIPVisionModelPool,
+        PeftModelForSeq2SeqLMPool,
+        ResNetForImageClassificationPool,
+        Seq2SeqLMPool,
+        SequenceClassificationModelPool,
+    )
+    from .models import (
+        create_default_model_card,
+        load_model_card_template,
+        save_pretrained_with_remote_code,
+        separate_load,
+        separate_save,
+    )
+    from .programs import BaseHydraProgram, FabricModelFusionProgram
+    from .taskpool import (
+        BaseTaskPool,
+        CLIPVisionModelTaskPool,
+        DummyTaskPool,
+        GPT2TextClassificationTaskPool,
+        LMEvalHarnessTaskPool,
+        NYUv2TaskPool,
+        OpenCLIPVisionModelTaskPool,
+    )
+    from .utils import (
+        ArithmeticStateDict,
+        BoolStateDictType,
+        LazyStateDict,
+        StateDictType,
+        TorchModelType,
+        cache_with_joblib,
+        get_rankzero_logger,
+        import_object,
+        instantiate,
+        parse_dtype,
+        print_parameters,
+        seed_everything_by_time,
+        set_default_cache_dir,
+        set_print_function_call,
+        set_print_function_call_permeanent,
+        timeit_context,
+    )
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
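
The change above switches the top-level package to deferred imports: the names listed in _import_structure are resolved by LazyImporter on first attribute access, while the TYPE_CHECKING branch keeps the same names visible to static type checkers and IDEs. A minimal sketch of how such a lazy module behaves follows; it illustrates the general pattern (popularized by transformers' _LazyModule), not FusionBench's actual LazyImporter, whose constructor signature is only inferred from the call above:

import importlib
import types


class LazyModule(types.ModuleType):
    """Defer submodule imports until an exported name is first accessed."""

    def __init__(self, name, module_file, import_structure, extra_objects=None):
        super().__init__(name)
        self.__file__ = module_file
        # Map each exported name back to the submodule that defines it.
        self._name_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }
        self._objects = dict(extra_objects or {})
        self.__all__ = list(self._name_to_module) + list(self._objects)

    def __getattr__(self, name):
        if name in self._objects:
            return self._objects[name]
        if name in self._name_to_module:
            submodule = importlib.import_module(
                f".{self._name_to_module[name]}", self.__name__
            )
            value = getattr(submodule, name)
            setattr(self, name, value)  # cache so __getattr__ is not hit again
            return value
        raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")

The practical effect is that `import fusion_bench` stays cheap: the heavy submodules behind, for example, `fusion_bench.method` are only imported the first time one of their names is touched.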
fusion_bench/dataset/__init__.py
CHANGED

@@ -1,16 +1,20 @@
 # flake8: noqa F401
-
-from
+import sys
+from typing import TYPE_CHECKING
 
-from
+from omegaconf import DictConfig, open_dict
 
-from .
+from fusion_bench.utils.lazy_imports import LazyImporter
 
 
 def load_dataset_from_config(dataset_config: DictConfig):
     """
     Load the dataset from the configuration.
     """
+    from datasets import load_dataset
+
+    from fusion_bench.utils import instantiate
+
     assert hasattr(dataset_config, "type"), "Dataset type not specified"
     if dataset_config.type == "instantiate":
         return instantiate(dataset_config.object)

@@ -27,3 +31,22 @@ def load_dataset_from_config(dataset_config: DictConfig):
         return dataset
     else:
         raise ValueError(f"Unknown dataset type: {dataset_config.type}")
+
+
+_extra_objects = {
+    "load_dataset_from_config": load_dataset_from_config,
+}
+_import_structure = {
+    "clip_dataset": ["CLIPDataset"],
+}
+
+if TYPE_CHECKING:
+    from .clip_dataset import CLIPDataset
+
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
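
Two smaller points in this file: load_dataset_from_config now performs its `datasets` and `instantiate` imports inside the function body, so importing fusion_bench.dataset no longer loads the heavyweight `datasets` package up front, and the module itself is replaced by the same lazy-import machinery as the package root.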

fusion_bench/dataset/clip_dataset.py
CHANGED

@@ -6,7 +6,7 @@ from typing import Optional, Tuple
 
 import torch
 from torch.utils.data import Dataset
-from transformers import CLIPProcessor, ProcessorMixin
+from transformers import BaseImageProcessor, CLIPProcessor, ProcessorMixin
 
 __all__ = ["CLIPDataset"]
 

@@ -60,7 +60,7 @@ class CLIPDataset(torch.utils.data.Dataset):
             raise ValueError("Each item should be a dictionary or a tuple of length 2")
         image = item["image"]
         if self.processor is not None:
-            if isinstance(self.processor, ProcessorMixin):
+            if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
                 # Apply the processor to the image to get the input tensor
                 inputs = self.processor(images=[image], return_tensors="pt")[
                     "pixel_values"
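
With the widened isinstance check, CLIPDataset accepts a bare image processor (any BaseImageProcessor subclass, e.g. one returned by AutoImageProcessor) in addition to a full ProcessorMixin such as CLIPProcessor, which is what the new ResNet model pool needs. A hypothetical usage sketch; the checkpoint name and in-memory items are placeholders, and the (tensor, label) return shape is inferred from how the new fine-tuning method consumes the dataset:

from PIL import Image
from transformers import AutoImageProcessor

from fusion_bench.dataset import CLIPDataset

# AutoImageProcessor returns a BaseImageProcessor subclass, which previously
# fell through the ProcessorMixin-only branch.
processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50")

# Any sequence of {"image": ..., "label": ...} items (or (image, label) tuples)
# can back the dataset; a real run would pass a torch/HuggingFace dataset here.
items = [{"image": Image.new("RGB", (32, 32)), "label": 0}]
dataset = CLIPDataset(items, processor=processor)
pixel_values, label = dataset[0]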
fusion_bench/method/__init__.py
CHANGED

@@ -2,6 +2,7 @@
 import sys
 from typing import TYPE_CHECKING
 
+from fusion_bench.utils import join_lists
 from fusion_bench.utils.lazy_imports import LazyImporter
 
 _import_structure = {

@@ -12,6 +13,8 @@ _import_structure = {
     "classification": [
         "ImageClassificationFineTuningForCLIP",
         "ContinualImageClassificationFineTuningForCLIP",
+        "ImageClassificationFineTuning",
+        "ImageClassificationFineTuning_Test",
     ],
     "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT", "BradleyTerryRewardModeling"],
     # analysis

@@ -131,7 +134,10 @@ _import_structure = {
         "ProgressivePruningForMixtral",
     ],
 }
-
+_available_algorithms = join_lists(list(_import_structure.values()))
+_extra_objects = {
+    "_available_algorithms": _available_algorithms,
+}
 
 if TYPE_CHECKING:
     from .ada_svd import AdaSVDMergingForCLIPVisionModel

@@ -141,6 +147,8 @@ if TYPE_CHECKING:
     from .bitdelta import BitDeltaAlgorithm
     from .classification import (
         ContinualImageClassificationFineTuningForCLIP,
+        ImageClassificationFineTuning,
+        ImageClassificationFineTuning_Test,
         ImageClassificationFineTuningForCLIP,
     )
     from .concrete_subspace import (

@@ -252,4 +260,5 @@ else:
         __name__,
         globals()["__file__"],
         _import_structure,
+        extra_objects=_extra_objects,
     )
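
`_available_algorithms` flattens every class name registered in `_import_structure` into one list, which the new top-level `fusion_bench/__init__.py` reuses as the lazy export list for `method`. Assuming `join_lists` is a plain concatenation helper (its implementation lives in `fusion_bench.utils` and is not shown in this diff), the behaviour is:

def join_lists(lists):
    # Assumed semantics: flatten a list of lists into a single list.
    return [item for sub in lists for item in sub]


_import_structure = {
    "classification": ["ImageClassificationFineTuning"],
    "lm_finetune": ["FullFinetuneSFT", "PeftFinetuneSFT"],
}
assert join_lists(list(_import_structure.values())) == [
    "ImageClassificationFineTuning",
    "FullFinetuneSFT",
    "PeftFinetuneSFT",
]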

fusion_bench/method/classification/__init__.py
CHANGED

@@ -1,3 +1,28 @@
 # flake8: noqa F401
-
-from
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+_import_structure = {
+    "clip_finetune": ["ImageClassificationFineTuningForCLIP"],
+    "continual_clip_finetune": ["ContinualImageClassificationFineTuningForCLIP"],
+    "image_classification_finetune": [
+        "ImageClassificationFineTuning",
+        "ImageClassificationFineTuning_Test",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .clip_finetune import ImageClassificationFineTuningForCLIP
+    from .continual_clip_finetune import ContinualImageClassificationFineTuningForCLIP
+    from .image_classification_finetune import (
+        ImageClassificationFineTuning,
+        ImageClassificationFineTuning_Test,
+    )
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+    )

fusion_bench/method/classification/image_classification_finetune.py
ADDED

@@ -0,0 +1,214 @@
+import os
+from typing import Optional
+
+import lightning as L
+import lightning.pytorch.callbacks as pl_callbacks
+import torch
+from lightning.pytorch.loggers import TensorBoardLogger
+from lightning_utilities.core.rank_zero import rank_zero_only
+from lit_learn.lit_modules import ERM_LitModule
+from omegaconf import DictConfig
+from torch import nn
+from torch.utils.data import DataLoader
+from torchmetrics.classification import Accuracy
+
+from fusion_bench import (
+    BaseAlgorithm,
+    BaseModelPool,
+    RuntimeConstants,
+    auto_register_config,
+    get_rankzero_logger,
+    instantiate,
+)
+from fusion_bench.dataset import CLIPDataset
+from fusion_bench.modelpool import ResNetForImageClassificationPool
+from fusion_bench.tasks.clip_classification import get_num_classes
+
+log = get_rankzero_logger(__name__)
+
+
+@auto_register_config
+class ImageClassificationFineTuning(BaseAlgorithm):
+    def __init__(
+        self,
+        max_epochs: Optional[int],
+        max_steps: Optional[int],
+        label_smoothing: float,
+        optimizer: DictConfig,
+        lr_scheduler: DictConfig,
+        dataloader_kwargs: DictConfig,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        assert (max_epochs is None) or (
+            max_steps is None or max_steps < 0
+        ), "Only one of max_epochs or max_steps should be set."
+        self.training_interval = "epoch" if max_epochs is not None else "step"
+        if self.training_interval == "epoch":
+            self.max_steps = -1
+        log.info(f"Training interval: {self.training_interval}")
+        log.info(f"Max epochs: {max_epochs}, max steps: {max_steps}")
+
+    def run(self, modelpool: ResNetForImageClassificationPool):
+        # load model and dataset
+        model = modelpool.load_pretrained_or_first_model()
+        assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
+
+        assert (
+            len(modelpool.train_dataset_names) == 1
+        ), "Exactly one training dataset is required."
+        self.dataset_name = dataset_name = modelpool.train_dataset_names[0]
+        num_classes = get_num_classes(dataset_name)
+        train_dataset = modelpool.load_train_dataset(dataset_name)
+        train_dataset = CLIPDataset(
+            train_dataset, processor=modelpool.load_processor(stage="train")
+        )
+        train_loader = self.get_dataloader(train_dataset, stage="train")
+        if modelpool.has_val_dataset:
+            val_dataset = modelpool.load_val_dataset(dataset_name)
+            val_dataset = CLIPDataset(
+                val_dataset, processor=modelpool.load_processor(stage="val")
+            )
+            val_loader = self.get_dataloader(val_dataset, stage="val")
+
+        # configure optimizer
+        optimizer = instantiate(self.optimizer, params=model.parameters())
+        if self.lr_scheduler is not None:
+            lr_scheduler = instantiate(self.lr_scheduler, optimizer=optimizer)
+            optimizer = {
+                "optimizer": optimizer,
+                "lr_scheduler": {
+                    "scheduler": lr_scheduler,
+                    "interval": self.training_interval,
+                    "frequency": 1,
+                },
+            }
+        log.info(f"optimizer:\n{optimizer}")
+
+        lit_module = ERM_LitModule(
+            model,
+            optimizer,
+            objective=nn.CrossEntropyLoss(label_smoothing=self.label_smoothing),
+            metrics={
+                "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                "acc@5": Accuracy(task="multiclass", num_classes=num_classes, top_k=5),
+            },
+        )
+
+        log_dir = (
+            self._program.path.log_dir
+            if self._program is not None
+            else "outputs/lightning_logs"
+        )
+        trainer = L.Trainer(
+            max_epochs=self.max_epochs,
+            max_steps=self.max_steps,
+            accelerator="auto",
+            devices="auto",
+            callbacks=[
+                pl_callbacks.LearningRateMonitor(logging_interval="step"),
+                pl_callbacks.DeviceStatsMonitor(),
+            ],
+            logger=TensorBoardLogger(
+                save_dir=log_dir,
+                name="",
+            ),
+            fast_dev_run=RuntimeConstants.debug,
+        )
+
+        trainer.fit(
+            lit_module, train_dataloaders=train_loader, val_dataloaders=val_loader
+        )
+        model = lit_module.model
+        if rank_zero_only.rank == 0:
+            log.info(f"Saving the final model to {log_dir}/raw_checkpoints/final")
+            modelpool.save_model(
+                model,
+                path=os.path.join(
+                    trainer.log_dir if trainer.log_dir is not None else log_dir,
+                    "raw_checkpoints",
+                    "final",
+                ),
+            )
+        return model
+
+    def get_dataloader(self, dataset, stage: str):
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
+
+
+@auto_register_config
+class ImageClassificationFineTuning_Test(BaseAlgorithm):
+    def __init__(self, checkpoint_path: str, dataloader_kwargs: DictConfig, **kwargs):
+        super().__init__(**kwargs)
+
+    def run(self, modelpool: BaseModelPool):
+        assert (
+            modelpool.has_val_dataset or modelpool.has_test_dataset
+        ), "No validation or test dataset found in the model pool."
+
+        # load model and dataset
+        model = modelpool.load_pretrained_or_first_model()
+        assert isinstance(model, nn.Module), "Loaded model is not a nn.Module."
+
+        if modelpool.has_test_dataset:
+            assert (
+                len(modelpool.test_dataset_names) == 1
+            ), "Exactly one test dataset is required."
+            self.dataset_name = dataset_name = modelpool.test_dataset_names[0]
+            dataset = modelpool.load_test_dataset(dataset_name)
+            dataset = CLIPDataset(
+                dataset, processor=modelpool.load_processor(stage="test")
+            )
+        else:
+            assert (
+                len(modelpool.val_dataset_names) == 1
+            ), "Exactly one validation dataset is required."
+            self.dataset_name = dataset_name = modelpool.val_dataset_names[0]
+            dataset = modelpool.load_val_dataset(dataset_name)
+            dataset = CLIPDataset(
+                dataset, processor=modelpool.load_processor(stage="test")
+            )
+        num_classes = get_num_classes(dataset_name)
+
+        test_loader = self.get_dataloader(dataset, stage="test")
+
+        if self.checkpoint_path is None:
+            lit_module = ERM_LitModule(
+                model,
+                metrics={
+                    "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                    "acc@5": Accuracy(
+                        task="multiclass", num_classes=num_classes, top_k=5
+                    ),
+                },
+            )
+        else:
+            lit_module = ERM_LitModule.load_from_checkpoint(
+                checkpoint_path=self.checkpoint_path,
+                model=model,
+                metrics={
+                    "acc@1": Accuracy(task="multiclass", num_classes=num_classes),
+                    "acc@5": Accuracy(
+                        task="multiclass", num_classes=num_classes, top_k=5
+                    ),
+                },
+            )
+
+        trainer = L.Trainer(
+            devices=1, num_nodes=1, logger=False, fast_dev_run=RuntimeConstants.debug
+        )
+
+        test_metrics = trainer.test(lit_module, dataloaders=test_loader)
+        log.info(f"Test metrics: {test_metrics}")
+        return model
+
+    def get_dataloader(self, dataset, stage: str):
+        assert stage in ["train", "val", "test"], f"Invalid stage: {stage}"
+        dataloader_kwargs = dict(self.dataloader_kwargs)
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = stage == "train"
+        return DataLoader(dataset, **dataloader_kwargs)
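
A few notes on the new module: `@auto_register_config` evidently records the constructor arguments as attributes (`run` reads `self.optimizer`, `self.max_epochs`, and `self.dataloader_kwargs` without assigning them); `optimizer` and `lr_scheduler` are Hydra-style configs resolved via `instantiate`; `ERM_LitModule` comes from the external `lit_learn` package, a new dependency; and `val_loader` is only bound when the pool has a validation split, so `run` as written assumes one is present. A hypothetical programmatic invocation follows; the optimizer and scheduler targets are illustrative, not taken from the shipped `image_classification_finetune.yaml`, whose contents this diff does not show:

from omegaconf import DictConfig

from fusion_bench.method.classification import ImageClassificationFineTuning

algorithm = ImageClassificationFineTuning(
    max_epochs=10,
    max_steps=None,
    label_smoothing=0.1,
    optimizer=DictConfig({"_target_": "torch.optim.SGD", "lr": 0.1, "momentum": 0.9}),
    lr_scheduler=DictConfig(
        {"_target_": "torch.optim.lr_scheduler.CosineAnnealingLR", "T_max": 10}
    ),
    dataloader_kwargs=DictConfig({"batch_size": 128, "num_workers": 8}),
)
# `run` expects a model pool exposing exactly one training dataset, e.g. one of
# the new ResNet/CIFAR pools configured under
# fusion_bench_config/modelpool/ResNetForImageClassfication/:
# trained_model = algorithm.run(modelpool)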
fusion_bench/method/opcm/opcm.py
CHANGED

@@ -87,6 +87,7 @@ class OPCMForCLIP(
         # get the average model
         with self.profile("loading model"):
             merged_model = modelpool.load_model(model_names[0])
+        assert merged_model is not None, "Failed to load the first model"
 
         if self.evaluate_on_every_step:
             with self.profile("evaluating model"):

fusion_bench/method/tall_mask/task_arithmetic.py
CHANGED

@@ -15,7 +15,7 @@ from fusion_bench.utils.state_dict_arithmetic import (
     state_dict_add,
     state_dict_binary_mask,
     state_dict_diff_abs,
-
+    state_dict_hadamard_product,
     state_dict_mul,
     state_dict_sub,
     state_dict_sum,

@@ -111,7 +111,7 @@ class TallMaskTaskArithmeticAlgorithm(
 
         with self.profile("compress and retrieve"):
            for model_name in modelpool.model_names:
-                retrieved_task_vector =
+                retrieved_task_vector = state_dict_hadamard_product(
                     tall_masks[model_name], multi_task_vector
                 )
                 retrieved_state_dict = state_dict_add(
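
`state_dict_hadamard_product` is part of the heavily expanded `state_dict_arithmetic` module (+935/-162 in this release) and is used here to apply each binary TALL mask to the merged multi-task vector elementwise, before the retrieved task vector is added back via `state_dict_add`. Judging from its siblings (`state_dict_add`, `state_dict_mul`, `state_dict_sub`), it plausibly reduces to the following sketch; these are assumed semantics, not the shipped implementation:

from typing import Dict

import torch

StateDictType = Dict[str, torch.Tensor]


def state_dict_hadamard_product(a: StateDictType, b: StateDictType) -> StateDictType:
    # Key-wise elementwise (Hadamard) product of two aligned state dicts.
    return {key: a[key] * b[key] for key in a}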
fusion_bench/mixins/__init__.py
CHANGED

@@ -11,6 +11,7 @@ _import_structure = {
     "hydra_config": ["HydraConfigMixin"],
     "lightning_fabric": ["LightningFabricMixin"],
     "openclip_classification": ["OpenCLIPClassificationMixin"],
+    "pyinstrument": ["PyinstrumentProfilerMixin"],
     "serialization": [
         "BaseYAMLSerializable",
         "YAMLSerializationMixin",

@@ -25,6 +26,7 @@ if TYPE_CHECKING:
     from .hydra_config import HydraConfigMixin
     from .lightning_fabric import LightningFabricMixin
     from .openclip_classification import OpenCLIPClassificationMixin
+    from .pyinstrument import PyinstrumentProfilerMixin
     from .serialization import (
         BaseYAMLSerializable,
         YAMLSerializationMixin,
|