fusion-bench 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +4 -0
- fusion_bench/dataset/clip_dataset.py +1 -0
- fusion_bench/method/__init__.py +2 -0
- fusion_bench/method/adamerging/__init__.py +28 -5
- fusion_bench/method/adamerging/resnet_adamerging.py +279 -0
- fusion_bench/method/adamerging/task_wise_adamerging.py +2 -14
- fusion_bench/method/adamerging/utils.py +58 -0
- fusion_bench/method/classification/image_classification_finetune.py +168 -12
- fusion_bench/method/dare/simple_average.py +3 -2
- fusion_bench/method/dare/task_arithmetic.py +3 -2
- fusion_bench/method/simple_average.py +6 -4
- fusion_bench/method/task_arithmetic/task_arithmetic.py +4 -1
- fusion_bench/mixins/lightning_fabric.py +9 -0
- fusion_bench/modelpool/__init__.py +24 -2
- fusion_bench/modelpool/base_pool.py +8 -1
- fusion_bench/modelpool/causal_lm/causal_lm.py +2 -1
- fusion_bench/modelpool/convnext_for_image_classification.py +198 -0
- fusion_bench/modelpool/dinov2_for_image_classification.py +197 -0
- fusion_bench/modelpool/resnet_for_image_classification.py +289 -5
- fusion_bench/models/hf_clip.py +4 -7
- fusion_bench/models/hf_utils.py +4 -1
- fusion_bench/models/model_card_templates/default.md +1 -1
- fusion_bench/taskpool/__init__.py +2 -0
- fusion_bench/taskpool/clip_vision/taskpool.py +1 -1
- fusion_bench/taskpool/resnet_for_image_classification.py +231 -0
- fusion_bench/utils/json.py +49 -8
- fusion_bench/utils/state_dict_arithmetic.py +91 -10
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/METADATA +2 -2
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/RECORD +124 -62
- fusion_bench_config/fabric/auto.yaml +1 -1
- fusion_bench_config/fabric/loggers/swandb_logger.yaml +5 -0
- fusion_bench_config/fabric/loggers/wandb_logger.yaml +2 -0
- fusion_bench_config/fabric_model_fusion.yaml +1 -0
- fusion_bench_config/method/adamerging/resnet.yaml +18 -0
- fusion_bench_config/method/classification/clip_finetune.yaml +5 -0
- fusion_bench_config/method/classification/image_classification_finetune.yaml +9 -0
- fusion_bench_config/method/linear/expo.yaml +5 -0
- fusion_bench_config/method/linear/llama_expo.yaml +5 -0
- fusion_bench_config/method/linear/llama_expo_with_dare.yaml +3 -0
- fusion_bench_config/method/linear/simple_average_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +3 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +5 -0
- fusion_bench_config/method/linear/weighted_average_for_llama.yaml +5 -0
- fusion_bench_config/method/mixtral_moe_merging.yaml +3 -0
- fusion_bench_config/method/mixtral_moe_upscaling.yaml +5 -0
- fusion_bench_config/method/regmean/clip_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/gpt2_regmean.yaml +3 -0
- fusion_bench_config/method/regmean/regmean.yaml +3 -0
- fusion_bench_config/method/regmean_plusplus/clip_regmean_plusplus.yaml +3 -0
- fusion_bench_config/method/smile_upscaling/causal_lm_upscaling.yaml +6 -0
- fusion_bench_config/method/smile_upscaling/error_accumulation.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/projected_energy.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/singular_projection_merging.yaml +3 -0
- fusion_bench_config/method/smile_upscaling/smile_mistral_upscaling.yaml +5 -0
- fusion_bench_config/method/smile_upscaling/smile_qwen2_upscaling.yaml +5 -0
- fusion_bench_config/method/wudi/wudi.yaml +3 -0
- fusion_bench_config/model_fusion.yaml +2 -1
- fusion_bench_config/modelpool/ConvNextForImageClassification/convnext-base-224.yaml +10 -0
- fusion_bench_config/modelpool/Dinov2ForImageClassification/dinov2-base-imagenet1k-1-layer.yaml +10 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/_generate_config.py +138 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet152_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet152_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet18_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet18_svhn.yaml +14 -0
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar10.yaml +1 -1
- fusion_bench_config/modelpool/{ResNetForImageClassfication → ResNetForImageClassification}/transformers/resnet50_cifar100.yaml +1 -1
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_dtd.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_emnist_letters.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_eurosat.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fashion_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_fer2013.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_food101.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_gtsrb.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_kmnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_mnist.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford-iiit-pet.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_oxford_flowers102.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_pcam.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_rendered-sst2.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_resisc45.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stanford-cars.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_stl10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_sun397.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassification/transformers/resnet50_svhn.yaml +14 -0
- fusion_bench_config/method/clip_finetune.yaml +0 -26
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.26.dist-info → fusion_bench-0.2.28.dist-info}/top_level.txt +0 -0
fusion_bench/__init__.py
CHANGED
|
@@ -41,6 +41,8 @@ _import_structure = {
|
|
|
41
41
|
"CausalLMBackbonePool",
|
|
42
42
|
"CausalLMPool",
|
|
43
43
|
"CLIPVisionModelPool",
|
|
44
|
+
"ConvNextForImageClassificationPool",
|
|
45
|
+
"Dinov2ForImageClassificationPool",
|
|
44
46
|
"GPT2ForSequenceClassificationPool",
|
|
45
47
|
"HuggingFaceGPT2ClassificationPool",
|
|
46
48
|
"NYUv2ModelPool",
|
|
@@ -107,6 +109,8 @@ if TYPE_CHECKING:
|
|
|
107
109
|
CausalLMBackbonePool,
|
|
108
110
|
CausalLMPool,
|
|
109
111
|
CLIPVisionModelPool,
|
|
112
|
+
ConvNextForImageClassificationPool,
|
|
113
|
+
Dinov2ForImageClassificationPool,
|
|
110
114
|
GPT2ForSequenceClassificationPool,
|
|
111
115
|
HuggingFaceGPT2ClassificationPool,
|
|
112
116
|
NYUv2ModelPool,
|
|
@@ -62,6 +62,7 @@ class CLIPDataset(torch.utils.data.Dataset):
|
|
|
62
62
|
if self.processor is not None:
|
|
63
63
|
if isinstance(self.processor, (ProcessorMixin, BaseImageProcessor)):
|
|
64
64
|
# Apply the processor to the image to get the input tensor
|
|
65
|
+
image = image.convert("RGB") # ensure image is in RGB format
|
|
65
66
|
inputs = self.processor(images=[image], return_tensors="pt")[
|
|
66
67
|
"pixel_values"
|
|
67
68
|
][0]
|
fusion_bench/method/__init__.py
CHANGED
|
@@ -55,6 +55,8 @@ _import_structure = {
|
|
|
55
55
|
"GPT2LayerWiseAdaMergingAlgorithm",
|
|
56
56
|
"LayerWiseAdaMergingForLlamaSFT",
|
|
57
57
|
"FlanT5LayerWiseAdaMergingAlgorithm",
|
|
58
|
+
"ResNetLayerWiseAdamerging",
|
|
59
|
+
"ResNetTaskWiseAdamerging",
|
|
58
60
|
],
|
|
59
61
|
"pwe_moe": [
|
|
60
62
|
"PWEMoELinearScalarizationForCLIP",
|
|
@@ -1,6 +1,29 @@
|
|
|
1
1
|
# flake8: noqa F401
|
|
2
|
-
|
|
3
|
-
from
|
|
4
|
-
|
|
5
|
-
from .
|
|
6
|
-
|
|
2
|
+
import sys
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from fusion_bench.utils.lazy_imports import LazyImporter
|
|
6
|
+
|
|
7
|
+
# Submodule name -> public names it exports. This mapping is handed to
# LazyImporter below; presumably it defers importing a submodule until one of
# its names is first accessed (TODO confirm against fusion_bench.utils.lazy_imports).
_import_structure = {
    "clip_layer_wise_adamerging": ["CLIPLayerWiseAdaMergingAlgorithm"],
    "clip_task_wise_adamerging": ["CLIPTaskWiseAdaMergingAlgorithm"],
    "flan_t5_layer_wise_adamerging": ["FlanT5LayerWiseAdaMergingAlgorithm"],
    "gpt2_layer_wise_adamerging": ["GPT2LayerWiseAdaMergingAlgorithm"],
    "llama_adamerging": ["LayerWiseAdaMergingForLlamaSFT"],
    "resnet_adamerging": ["ResNetLayerWiseAdamerging", "ResNetTaskWiseAdamerging"],
}

if TYPE_CHECKING:
    # Eager imports for static type checkers and IDEs only; TYPE_CHECKING is
    # False at runtime, so this branch never executes. Keep it in sync with
    # `_import_structure` above.
    from .clip_layer_wise_adamerging import CLIPLayerWiseAdaMergingAlgorithm
    from .clip_task_wise_adamerging import CLIPTaskWiseAdaMergingAlgorithm
    from .flan_t5_layer_wise_adamerging import FlanT5LayerWiseAdaMergingAlgorithm
    from .gpt2_layer_wise_adamerging import GPT2LayerWiseAdaMergingAlgorithm
    from .llama_adamerging import LayerWiseAdaMergingForLlamaSFT
    from .resnet_adamerging import ResNetLayerWiseAdamerging, ResNetTaskWiseAdamerging

else:
    # At runtime, replace this package module in sys.modules with a
    # LazyImporter proxy built from `_import_structure`.
    sys.modules[__name__] = LazyImporter(
        __name__,
        globals()["__file__"],
        _import_structure,
    )
|
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import TYPE_CHECKING, Dict, Iterator, Optional, Union, override
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from omegaconf import DictConfig
|
|
7
|
+
from torch import nn
|
|
8
|
+
from torch.utils.data import DataLoader
|
|
9
|
+
from tqdm import tqdm
|
|
10
|
+
|
|
11
|
+
from fusion_bench import (
|
|
12
|
+
BaseAlgorithm,
|
|
13
|
+
LightningFabricMixin,
|
|
14
|
+
auto_register_config,
|
|
15
|
+
get_rankzero_logger,
|
|
16
|
+
instantiate,
|
|
17
|
+
)
|
|
18
|
+
from fusion_bench.constants import RuntimeConstants
|
|
19
|
+
from fusion_bench.dataset import CLIPDataset
|
|
20
|
+
from fusion_bench.modelpool import ResNetForImageClassificationPool
|
|
21
|
+
from fusion_bench.models.wrappers.layer_wise_fusion import LayerWiseMergedModel
|
|
22
|
+
from fusion_bench.models.wrappers.task_wise_fusion import TaskWiseMergedModel
|
|
23
|
+
from fusion_bench.utils import load_tensor_from_file
|
|
24
|
+
from fusion_bench.utils.data import InfiniteDataLoader
|
|
25
|
+
|
|
26
|
+
from .entropy_loss import entropy_loss
|
|
27
|
+
from .utils import construct_layer_wise_merged_model, construct_task_wise_merged_model
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from transformers import ResNetForImageClassification, ResNetModel
|
|
31
|
+
|
|
32
|
+
# Module-level logger. Judging by its name, get_rankzero_logger presumably
# emits only on the rank-zero process in multi-process runs (TODO confirm).
log = get_rankzero_logger(__name__)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@auto_register_config
class _ResNetAdaMergingBase(
    ABC,
    LightningFabricMixin,
    BaseAlgorithm,
):
    """Shared AdaMerging test-time-adaptation loop for ResNet image classifiers.

    Subclasses implement :meth:`setup_wrapped_model` to build a merged ResNet
    backbone whose merging weights are exposed as ``merge_weight``, and to fill
    :attr:`classification_heads`. :meth:`run` then tunes the merging weights by
    minimizing the entropy of each task head's predictions on shuffled test
    batches, and returns the unloaded merged model.
    """

    # Task/model name -> classification head used to score that task's batches
    # (filled by ``setup_wrapped_model``).
    classification_heads: Dict[str, nn.Module]
    # Task name -> endless iterator over that task's shuffled test data
    # (filled by ``setup_dataloaders``).
    shuffled_test_loader_iters: Dict[str, Iterator]

    def __init__(
        self,
        max_steps: int,
        optimizer: DictConfig,
        lr_scheduler: DictConfig,
        dataloader_kwargs: DictConfig,
        init_values: Optional[float],
        clamp_weights: bool = False,
        tie_weights: bool = True,
        strict: bool = False,
        resume_weights_path: Union[str, None] = None,
        **kwargs,
    ):
        """
        The constructor arguments are read back later as ``self.<name>``.
        They are presumably registered as attributes by ``@auto_register_config``.

        Args:
            max_steps: Number of test-time-adaptation steps. If ``<= 0``, training
                is skipped and the merged model built from the initial or
                resumed weights is returned.
            optimizer: Config instantiated with ``params=[merge_weight]``.
            lr_scheduler: Config instantiated with ``optimizer=...``, or ``None``
                for no scheduler.
            dataloader_kwargs: Keyword arguments for each test ``DataLoader``.
                ``shuffle`` is always forced to ``True``.
            init_values: If not ``None``, every merging weight is filled with
                this value.
            clamp_weights: Forwarded to the merged-model wrapper.
            tie_weights: Forwarded to the merged-model wrapper.
            strict: Forwarded to the merged-model wrapper.
            resume_weights_path: Optional file of merging weights to load.
        """
        super().__init__(**kwargs)
        # The class-level annotations above do not create attributes. Without
        # these, the first item assignment in setup_wrapped_model() or
        # setup_dataloaders() raises AttributeError.
        self.classification_heads = {}
        self.shuffled_test_loader_iters = {}
        if RuntimeConstants.debug:
            log.info("Debug mode is on, setting max_steps to 10")
            self.max_steps = 10

    @override
    def run(self, modelpool: ResNetForImageClassificationPool):
        """Merge the models in ``modelpool`` with AdaMerging.

        If ``self.log_dir`` is set, the learned merging weights are saved to
        ``checkpoints/merge_weight.ckpt`` and the merged model to
        ``checkpoints/merged_model``.

        Args:
            modelpool: Pool with the pretrained and fine-tuned ResNet models.

        Returns:
            The result of ``merge_and_unload()`` on the wrapped merged model.
        """
        self.modelpool = modelpool

        # setup models
        wrapped_model = self.setup_wrapped_model(modelpool)

        # if max_steps <= 0, skip training and return the merged model directly
        # this can be used to evaluate the merging weights loaded from `resume_weights_path`
        if self.max_steps <= 0:
            return wrapped_model.merge_and_unload()

        # setup dataloaders
        self.setup_dataloaders()

        # configure optimizer and lr_scheduler
        optimizer = instantiate(self.optimizer, params=[wrapped_model.merge_weight])
        lr_scheduler = (
            instantiate(self.lr_scheduler, optimizer=optimizer)
            if self.lr_scheduler is not None
            else None
        )

        wrapped_model, optimizer = self.fabric.setup(wrapped_model, optimizer)
        # The heads are not part of `wrapped_model`, so fabric.setup() does not
        # move them. Place them on the Fabric device so they can consume the
        # backbone outputs.
        self.classification_heads = {
            task: self.fabric.to_device(head)
            for task, head in self.classification_heads.items()
        }
        wrapped_model = self.test_time_adaptation(
            wrapped_model, optimizer, lr_scheduler
        )

        # save merging weights
        if self.log_dir is not None:
            self.fabric.save(
                os.path.join(self.log_dir, "checkpoints", "merge_weight.ckpt"),
                {"merge_weight": wrapped_model.merge_weight},
            )

        merged_model = wrapped_model.merge_and_unload()
        if self.log_dir is not None:
            modelpool.save_model(
                merged_model,
                os.path.join(self.log_dir, "checkpoints", "merged_model"),
                algorithm_config=self.config,
                description="Merged ResNet model using AdaMerging (E Yang, 2023).",
            )

        return merged_model

    def test_time_adaptation(
        self,
        wrapped_model: TaskWiseMergedModel,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler],
    ):
        """Optimize the merging weights by entropy minimization.

        Each of the ``max_steps`` steps draws one batch per model name and
        computes the entropy loss of that task's head. The per-task losses are
        backpropagated one at a time, so gradients accumulate across tasks. The
        step ends with one optimizer step, a re-merge of the weights and an
        optional scheduler step. Losses are logged via ``fabric.log_dict``.

        Returns:
            The adapted ``wrapped_model``.
        """
        model_names = self.modelpool.model_names
        wrapped_model.train()
        wrapped_model.merge_weights()

        for step_idx in tqdm(
            range(self.max_steps),
            disable=not self.fabric.is_global_zero,
            dynamic_ncols=True,
        ):
            metrics = {"tta/total_loss": 0.0}
            for task in model_names:
                batch = next(self.get_shuffled_test_loader_iter(task))
                # The test loaders are plain DataLoaders, not set up through
                # Fabric, so batches must be moved to the model's device here.
                images = self.fabric.to_device(batch[0])
                logits = self.compute_logits(wrapped_model, images, task)
                loss = entropy_loss(logits)
                loss_value = loss.item()
                metrics[f"tta/{task}_loss"] = loss_value
                metrics["tta/total_loss"] += loss_value
                # retain_graph: every task's loss presumably backpropagates
                # through the graph built by the single merge_weights() call,
                # which would otherwise be freed after the first backward.
                self.fabric.backward(loss, retain_graph=True)

            optimizer.step()
            optimizer.zero_grad()
            wrapped_model.merge_weights()  # merge weights for the next step
            if lr_scheduler is not None:
                lr_scheduler.step()

            self.fabric.log_dict(metrics=metrics, step=step_idx)

        return wrapped_model

    def compute_logits(
        self, module: Union["ResNetModel", nn.Module], images: torch.Tensor, task: str
    ) -> torch.Tensor:
        """Run ``module`` on ``images`` and score ``pooler_output`` with ``task``'s head.

        Raises:
            NotImplementedError: If the model pool type is not ``"transformers"``.
        """
        if self.modelpool.type == "transformers":
            outputs = module(images, return_dict=True)
            pooled_output = outputs.pooler_output
            logits = self.classification_heads[task](pooled_output)
            return logits
        else:
            raise NotImplementedError(
                f"Model type {self.modelpool.type} is not supported."
            )

    def setup_dataloaders(self):
        """Build one shuffled, infinite test-data iterator per test dataset.

        Each dataset is wrapped in ``CLIPDataset`` with the pool's processor.
        The iterators are stored in ``self.shuffled_test_loader_iters``.
        """
        dataloader_kwargs = dict(self.dataloader_kwargs)
        dataloader_kwargs["shuffle"] = True  # ensure shuffling for TTA
        processor = self.modelpool.load_processor()
        for task in self.modelpool.test_dataset_names:
            test_dataset = self.modelpool.load_test_dataset(task)
            test_dataset = CLIPDataset(test_dataset, processor=processor)
            test_loader = DataLoader(test_dataset, **dataloader_kwargs)
            self.shuffled_test_loader_iters[task] = iter(
                InfiniteDataLoader(test_loader)
            )

    def get_shuffled_test_loader_iter(self, task: str):
        """Return the shuffled test-data iterator built for ``task``."""
        return self.shuffled_test_loader_iters[task]

    @abstractmethod
    def setup_wrapped_model(
        self, modelpool: ResNetForImageClassificationPool
    ) -> Union[TaskWiseMergedModel, LayerWiseMergedModel]:
        """
        Setup the wrapped merged model.

        Implementations are also expected to fill ``self.classification_heads``
        with one head per model name used during adaptation.

        Args:
            modelpool (ResNetForImageClassificationPool): The model pool containing pretrained and finetuned models.

        Returns:
            Union[TaskWiseMergedModel, LayerWiseMergedModel] : The wrapped merged model.
        """
        pass
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
class ResNetTaskWiseAdamerging(_ResNetAdaMergingBase):
    """AdaMerging for ResNet classifiers with task-wise merging weights."""

    @torch.no_grad()
    def setup_wrapped_model(self, modelpool: ResNetForImageClassificationPool):
        """Build a task-wise merged ResNet backbone and collect the per-task heads.

        Each fine-tuned model's ``classifier`` is frozen and stored in
        ``self.classification_heads`` under its model name. Only the ``resnet``
        backbones are merged.

        Args:
            modelpool: Pool providing the pretrained and fine-tuned models.

        Returns:
            The wrapper returned by ``construct_task_wise_merged_model``.

        Raises:
            NotImplementedError: If ``modelpool.type`` is not ``"transformers"``.
            ValueError: If the weights loaded from ``resume_weights_path`` do not
                match the shape of ``merge_weight``.
        """
        # Reject unsupported pool types before loading any (possibly large) models.
        if modelpool.type != "transformers":
            raise NotImplementedError(f"Model type {modelpool.type} is not supported.")

        pretrained_model: "ResNetForImageClassification" = (
            modelpool.load_pretrained_model()
        )
        finetuned_models: Dict[str, "ResNetForImageClassification"] = dict(
            modelpool.named_models()
        )

        for model_name, model in finetuned_models.items():
            # Heads stay task-specific and fixed; only backbones are merged.
            head = model.classifier
            head.requires_grad_(False)
            self.classification_heads[model_name] = head
        pretrained_backbone: "ResNetModel" = pretrained_model.resnet
        finetuned_backbones = [model.resnet for model in finetuned_models.values()]

        wrapped_model = construct_task_wise_merged_model(
            pretrained_model=pretrained_backbone,
            finetuned_models=finetuned_backbones,
            clamp_weights=self.clamp_weights,
            tie_weights=self.tie_weights,
            strict=self.strict,
        )

        if self.init_values is not None:
            log.info(f"Initializing merging weights to {self.init_values}")
            wrapped_model.merge_weight.data.fill_(self.init_values)

        # load merging weights if provided
        if self.resume_weights_path is not None:
            merging_weights = load_tensor_from_file(
                self.resume_weights_path, device="cpu"
            )
            log.info(f"Loaded merging weights from {self.resume_weights_path}")
            # Explicit check instead of `assert`, which `python -O` strips.
            if merging_weights.shape != wrapped_model.merge_weight.shape:
                raise ValueError(
                    f"Merging weights shape {merging_weights.shape} does not match "
                    f"model's merge_weight shape {wrapped_model.merge_weight.shape}."
                )
            wrapped_model.merge_weight.data = merging_weights
        return wrapped_model
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
class ResNetLayerWiseAdamerging(_ResNetAdaMergingBase):
    """AdaMerging for ResNet classifiers with layer-wise merging weights."""

    @torch.no_grad()
    def setup_wrapped_model(self, modelpool: ResNetForImageClassificationPool):
        """Build a layer-wise merged ResNet backbone and collect the per-task heads.

        Each fine-tuned model's ``classifier`` is frozen and stored in
        ``self.classification_heads`` under its model name. Only the ``resnet``
        backbones are merged.

        Args:
            modelpool: Pool providing the pretrained and fine-tuned models.

        Returns:
            The wrapper returned by ``construct_layer_wise_merged_model``.

        Raises:
            NotImplementedError: If ``modelpool.type`` is not ``"transformers"``.
            ValueError: If the weights loaded from ``resume_weights_path`` do not
                match the shape of ``merge_weight``.
        """
        # Reject unsupported pool types before loading any (possibly large) models.
        if modelpool.type != "transformers":
            raise NotImplementedError(f"Model type {modelpool.type} is not supported.")

        pretrained_model: "ResNetForImageClassification" = (
            modelpool.load_pretrained_model()
        )
        finetuned_models: Dict[str, "ResNetForImageClassification"] = dict(
            modelpool.named_models()
        )

        for model_name, model in finetuned_models.items():
            # Heads stay task-specific and fixed; only backbones are merged.
            head = model.classifier
            head.requires_grad_(False)
            self.classification_heads[model_name] = head
        pretrained_backbone: "ResNetModel" = pretrained_model.resnet
        finetuned_backbones = [model.resnet for model in finetuned_models.values()]

        wrapped_model = construct_layer_wise_merged_model(
            pretrained_model=pretrained_backbone,
            finetuned_models=finetuned_backbones,
            clamp_weights=self.clamp_weights,
            tie_weights=self.tie_weights,
            strict=self.strict,
        )

        if self.init_values is not None:
            log.info(f"Initializing merging weights to {self.init_values}")
            wrapped_model.merge_weight.data.fill_(self.init_values)

        # load merging weights if provided
        if self.resume_weights_path is not None:
            merging_weights = load_tensor_from_file(
                self.resume_weights_path, device="cpu"
            )
            log.info(f"Loaded merging weights from {self.resume_weights_path}")
            # Explicit check instead of `assert`, which `python -O` strips.
            if merging_weights.shape != wrapped_model.merge_weight.shape:
                raise ValueError(
                    f"Merging weights shape {merging_weights.shape} does not match "
                    f"model's merge_weight shape {wrapped_model.merge_weight.shape}."
                )
            wrapped_model.merge_weight.data = merging_weights
        return wrapped_model
|
|
@@ -18,21 +18,9 @@ from fusion_bench.models.wrappers.task_wise_fusion import (
|
|
|
18
18
|
get_task_wise_weights,
|
|
19
19
|
)
|
|
20
20
|
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
def entropy_loss(logits: Tensor) -> Tensor:
|
|
25
|
-
"""
|
|
26
|
-
Compute the entropy loss of a set of logits.
|
|
21
|
+
from .entropy_loss import entropy_loss
|
|
27
22
|
|
|
28
|
-
|
|
29
|
-
logits (Tensor): The logits to compute the entropy loss of.
|
|
30
|
-
|
|
31
|
-
Returns:
|
|
32
|
-
Tensor: The entropy loss of the logits.
|
|
33
|
-
"""
|
|
34
|
-
probs = torch.softmax(logits, dim=-1)
|
|
35
|
-
return -torch.sum(probs * torch.log(probs + 1e-8), dim=-1).mean()
|
|
23
|
+
# Module-level logger named after this module.
log = logging.getLogger(__name__)
|
|
36
24
|
|
|
37
25
|
|
|
38
26
|
class TaskWiseAdaMergingAlgorithm(ModelFusionAlgorithm):
|
|
@@ -1,4 +1,9 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
1
3
|
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
|
|
6
|
+
from fusion_bench.utils.type import TorchModelType
|
|
2
7
|
|
|
3
8
|
|
|
4
9
|
def get_memory_usage(desc):
|
|
@@ -13,3 +18,56 @@ def get_memory_usage(desc):
|
|
|
13
18
|
return (
|
|
14
19
|
f"{desc}\nAllocated Memory: {allocated:.2f} MB\nCached Memory: {cached:.2f} MB"
|
|
15
20
|
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@torch.no_grad()
def construct_task_wise_merged_model(
    pretrained_model: TorchModelType,
    finetuned_models: List[TorchModelType],
    clamp_weights: bool = False,
    tie_weights: bool = True,
    strict: bool = False,
):
    """Wrap a pretrained model and its fine-tuned variants in a ``TaskWiseMergedModel``.

    The initial merging weights come from ``get_task_wise_weights`` and are
    sized by the number of fine-tuned models.

    Args:
        pretrained_model: The shared base model.
        finetuned_models: Fine-tuned models derived from ``pretrained_model``.
        clamp_weights: Forwarded to ``TaskWiseMergedModel``.
        tie_weights: Forwarded to ``TaskWiseMergedModel``.
        strict: Forwarded to ``TaskWiseMergedModel``.

    Returns:
        TaskWiseMergedModel: The constructed wrapper.
    """
    # Function-scope import keeps the wrapper module out of this module's import time.
    from fusion_bench.models.wrappers.task_wise_fusion import (
        TaskWiseMergedModel,
        get_task_wise_weights,
    )

    shared_kwargs = dict(
        pretrained_model=pretrained_model,
        finetuned_models=finetuned_models,
        clamp_weights=clamp_weights,
        tie_weights=tie_weights,
        strict=strict,
    )
    initial_weights = get_task_wise_weights(num_models=len(finetuned_models))
    return TaskWiseMergedModel(task_wise_weight=initial_weights, **shared_kwargs)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@torch.no_grad()
def construct_layer_wise_merged_model(
    pretrained_model: TorchModelType,
    finetuned_models: List[TorchModelType],
    clamp_weights: bool = False,
    tie_weights: bool = True,
    strict: bool = False,
):
    """Wrap a pretrained model and its fine-tuned variants in a ``LayerWiseMergedModel``.

    The initial merging weights come from ``get_layer_wise_weights`` and are
    sized by the number of fine-tuned models and the number of trainable
    parameter tensors in ``pretrained_model``.

    Args:
        pretrained_model: The shared base model.
        finetuned_models: Fine-tuned models derived from ``pretrained_model``.
        clamp_weights: Forwarded to ``LayerWiseMergedModel``.
        tie_weights: Forwarded to ``LayerWiseMergedModel``.
        strict: Forwarded to ``LayerWiseMergedModel``.

    Returns:
        LayerWiseMergedModel: The constructed wrapper.
    """
    # Function-scope import keeps the wrapper module out of this module's import time.
    from fusion_bench.models.wrappers.layer_wise_fusion import (
        LayerWiseMergedModel,
        get_layer_wise_weights,
    )

    model_count = len(finetuned_models)
    # "Layers" here means trainable parameter tensors, not architectural layers.
    trainable_tensor_count = sum(
        1 for param in pretrained_model.parameters() if param.requires_grad
    )
    initial_weights = get_layer_wise_weights(
        num_models=model_count,
        num_layers=trainable_tensor_count,
    )

    return LayerWiseMergedModel(
        layer_wise_weight=initial_weights,
        pretrained_model=pretrained_model,
        finetuned_models=finetuned_models,
        clamp_weights=clamp_weights,
        tie_weights=tie_weights,
        strict=strict,
    )
|