fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +18 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/serialization.py +25 -78
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +2 -1
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/devices.py +30 -8
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +58 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +8 -3
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
fusion_bench/modelpool/resnet_for_image_classification.py
ADDED
@@ -0,0 +1,208 @@
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Literal,
+    Optional,
+    TypeVar,
+    Union,
+    override,
+)
+
+import torch
+from omegaconf import DictConfig
+from torch import nn
+
+from fusion_bench import BaseModelPool, auto_register_config, get_rankzero_logger
+from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
+
+if TYPE_CHECKING:
+    from torchvision.models import ResNet as TorchVisionResNet
+
+log = get_rankzero_logger(__name__)
+
+
+def load_torchvision_resnet(
+    model_name: str, weights: Optional[str], num_classes: Optional[int]
+) -> "TorchVisionResNet":
+    import torchvision.models
+
+    model_fn = getattr(torchvision.models, model_name)
+    model: "TorchVisionResNet" = model_fn(weights=weights)
+
+    if num_classes is not None:
+        model.fc = nn.Linear(model.fc.in_features, num_classes)
+
+    return model
+
+
+def load_transformers_resnet(
+    config_path: str, pretrained: bool, dataset_name: Optional[str]
+):
+    from transformers import AutoConfig, ResNetForImageClassification
+
+    if pretrained:
+        model = ResNetForImageClassification.from_pretrained(config_path)
+    else:
+        config = AutoConfig.from_pretrained(config_path)
+        model = ResNetForImageClassification(config)
+
+    if dataset_name is None:
+        return model
+
+    classnames = get_classnames(dataset_name)
+    id2label = {i: c for i, c in enumerate(classnames)}
+    label2id = {c: i for i, c in enumerate(classnames)}
+    model.config.id2label = id2label
+    model.config.label2id = label2id
+
+    model.classifier[1] = (
+        nn.Linear(
+            model.classifier[1].in_features,
+            len(classnames),
+        )
+        if model.config.num_labels > 0
+        else nn.Identity()
+    )
+    return model
+
+
+@auto_register_config
+class ResNetForImageClassificationPool(BaseModelPool):
+    def __init__(self, type: str, **kwargs):
+        super().__init__(**kwargs)
+        assert type in ["torchvision", "transformers"]
+
+    def load_processor(
+        self, stage: Literal["train", "val", "test"] = "test", *args, **kwargs
+    ):
+        if self.type == "torchvision":
+            from torchvision import transforms
+
+            to_tensor = transforms.ToTensor()
+            normalize = transforms.Normalize(
+                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+            )
+            if stage == "train":
+                train_transform = transforms.Compose(
+                    [
+                        transforms.RandomResizedCrop(224),
+                        transforms.RandomHorizontalFlip(),
+                        to_tensor,
+                        normalize,
+                    ]
+                )
+                return train_transform
+            else:
+                val_transform = transforms.Compose(
+                    [
+                        transforms.Resize(256),
+                        transforms.CenterCrop(224),
+                        to_tensor,
+                        normalize,
+                    ]
+                )
+                return val_transform
+
+        elif self.type == "transformers":
+            from transformers import AutoImageProcessor
+
+            if self.has_pretrained:
+                config_path = self._models["_pretrained_"].config_path
+            else:
+                for model_cfg in self._models.values():
+                    if isinstance(model_cfg, str):
+                        config_path = model_cfg
+                        break
+                    if "config_path" in model_cfg:
+                        config_path = model_cfg["config_path"]
+                        break
+            return AutoImageProcessor.from_pretrained(config_path)
+
+    @override
+    def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
+        log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
+        if (
+            isinstance(model_name_or_config, str)
+            and model_name_or_config in self._models
+        ):
+            model_name_or_config = self._models[model_name_or_config]
+
+        if self.type == "torchvision":
+            from torchvision.models import (
+                resnet18,
+                resnet34,
+                resnet50,
+                resnet101,
+                resnet152,
+            )
+
+            match model_name_or_config:
+                case "resnet18":
+                    model = resnet18()
+                case "resnet34":
+                    model = resnet34()
+                case "resnet50":
+                    model = resnet50()
+                case "resnet101":
+                    model = resnet101()
+                case "resnet152":
+                    model = resnet152()
+                case dict() | DictConfig() as model_config:
+                    if "dataset_name" in model_config:
+                        num_classes = get_num_classes(model_config["dataset_name"])
+                        if "num_classes" in model_config:
+                            assert (
+                                num_classes == model_config["num_classes"]
+                            ), f"num_classes mismatch: {num_classes} vs {model_config['num_classes']}"
+                    elif "num_classes" in model_config:
+                        num_classes = model_config["num_classes"]
+                    else:
+                        num_classes = None
+                    model = load_torchvision_resnet(
+                        model_name=model_config["model_name"],
+                        weights=model_config.get("weights", None),
+                        num_classes=num_classes,
+                    )
+                case _:
+                    raise ValueError(
+                        f"Invalid model_name_or_config type: {type(model_name_or_config)}"
+                    )
+        elif self.type == "transformers":
+            match model_name_or_config:
+                case str() as model_path:
+                    from transformers import AutoModelForImageClassification
+
+                    model = AutoModelForImageClassification.from_pretrained(model_path)
+                case dict() | DictConfig() as model_config:
+
+                    model = load_transformers_resnet(
+                        config_path=model_config["config_path"],
+                        pretrained=model_config.get("pretrained", False),
+                        dataset_name=model_config.get("dataset_name", None),
+                    )
+                case _:
+                    raise ValueError(
+                        f"Invalid model_name_or_config type: {type(model_name_or_config)}"
+                    )
+
+            # override forward to return logits only
+            original_forward = model.forward
+            model.forward = lambda pixel_values, **kwargs: original_forward(
+                pixel_values=pixel_values, **kwargs
+            ).logits
+            model.original_forward = original_forward
+        else:
+            raise ValueError(f"Unknown model type: {self.type}")
+        return model
+
+    @override
+    def save_model(self, model, path, *args, **kwargs):
+        if self.type == "torchvision":
+            torch.save(model.state_dict(), path)
+        elif self.type == "transformers":
+            model.save_pretrained(path)
+            self.load_processor().save_pretrained(path)
+        else:
+            raise ValueError(f"Unknown model type: {self.type}")
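For orientation, a minimal usage sketch of the new pool follows (not part of the diff; the `models` keyword layout and the `microsoft/resnet-50` checkpoint id are illustrative assumptions, loosely mirroring the bundled ResNetForImageClassfication/transformers/*.yaml configs):

# Hypothetical usage sketch; only `type` is declared explicitly in __init__ above,
# the remaining keyword arguments are assumed to follow BaseModelPool conventions.
from fusion_bench.modelpool.resnet_for_image_classification import (
    ResNetForImageClassificationPool,
)

pool = ResNetForImageClassificationPool(
    type="transformers",
    models={
        "cifar10": {
            "config_path": "microsoft/resnet-50",  # assumed checkpoint id
            "pretrained": True,
            "dataset_name": "cifar10",
        }
    },
)
model = pool.load_model("cifar10")  # patched forward() returns logits only
processor = pool.load_processor()   # AutoImageProcessor loaded from config_path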
fusion_bench/models/__init__.py
CHANGED
@@ -1,10 +1,36 @@
 # flake8: noqa F401
-
-
-
-from .
-
-
-
-
-
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+from . import utils
+
+_extra_objects = {
+    "utils": utils,
+}
+_import_structure = {
+    "hf_utils": [
+        "create_default_model_card",
+        "load_model_card_template",
+        "save_pretrained_with_remote_code",
+    ],
+    "parameter_dict": ["ParameterDictModel"],
+    "separate_io": ["separate_load", "separate_save"],
+}
+
+if TYPE_CHECKING:
+    from .hf_utils import (
+        create_default_model_card,
+        load_model_card_template,
+        save_pretrained_with_remote_code,
+    )
+    from .parameter_dict import ParameterDictModel
+    from .separate_io import separate_load, separate_save
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
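Illustrative effect of the lazy-import rewrite above (sketch only; it just exercises the names declared in `_import_structure` and `_extra_objects`):

import fusion_bench.models as models

# Submodules now resolve on first attribute access via LazyImporter rather
# than being imported eagerly when the package is loaded.
make_card = models.create_default_model_card    # triggers the import of .hf_utils
ParameterDictModel = models.ParameterDictModel  # triggers the import of .parameter_dict
utils = models.utils                            # available immediately via _extra_objects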
fusion_bench/models/hf_clip.py
CHANGED
@@ -195,5 +195,9 @@ class HFCLIPClassifier(nn.Module):
             pass
         elif isinstance(image_embeds, BaseModelOutputWithPooling):
             image_embeds = image_embeds[1]
+        elif isinstance(image_embeds, dict) and "pooler_output" in image_embeds:
+            image_embeds = image_embeds["pooler_output"]
+        else:
+            raise ValueError("Unsupported output type from vision model outputs")
         image_embeds = self.clip_model.visual_projection(image_embeds)
         return image_embeds
fusion_bench/models/hf_utils.py
CHANGED
@@ -143,7 +143,7 @@ def save_pretrained_with_remote_code(
 
 def create_default_model_card(
     models: list[str],
-
+    base_model: Optional[str] = None,
     title: str = "Deep Model Fusion",
     tags: list[str] = ["fusion-bench", "merge"],
     description=None,
@@ -154,6 +154,7 @@ def create_default_model_card(
 
     template: Template = Template(load_model_card_template("default.md"))
     card = template.render(
+        base_model=base_model,
         models=models,
         library_name="transformers",
         title=title,
fusion_bench/models/model_card_templates/default.md
CHANGED
@@ -1,5 +1,8 @@
 ---
 base_model:
+{%- if base_model is not none %}
+- {{ base_model }}
+{%- endif %}
 {%- for model in models %}
 - {{ model }}
 {%- endfor %}
@@ -18,7 +21,11 @@ tags:
 This is a merged model created using [fusion-bench](https://github.com/tanganke/fusion_bench).
 
 The following models were included in the merge:
-
+
+{% if base_model is not none %}
+- base model: {{ base_model }}
+{%- endif %}
+{%- for model in models %}
 - {{ model }}
 {%- endfor %}
 
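A short sketch of how the new `base_model` argument flows into this template (illustrative only; the repository ids are placeholders):

from fusion_bench.models.hf_utils import create_default_model_card

# When base_model is given, the rendered card lists it under `base_model:` and
# as "base model: ..." in the merge summary; omitting it keeps the old layout.
card_text = create_default_model_card(
    models=["org/finetuned-a", "org/finetuned-b"],  # placeholder repo ids
    base_model="org/base-model",                    # new optional argument
    title="Deep Model Fusion",
)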
fusion_bench/models/wrappers/ensemble.py
CHANGED
@@ -1,10 +1,17 @@
-
+import logging
+from typing import Any, Callable, Dict, Generic, List, Union, cast
 
 import numpy as np
 import torch
+import torch.futures
 from omegaconf import ListConfig
 from torch import Tensor, nn
 
+from fusion_bench.utils.devices import to_device
+from fusion_bench.utils.type import TorchModelType
+
+log = logging.getLogger(__name__)
+
 
 def aggregate_tensors(
     outputs: List[Any], aggregate_fn: Callable
@@ -58,12 +65,16 @@ def aggregate_tensors(
         raise ValueError("Unsupported type for outputs")
 
 
-class EnsembleModule(nn.Module):
+class EnsembleModule(nn.Module, Generic[TorchModelType]):
     """
     Ensemble module that averages the outputs of multiple models.
     """
 
-    def __init__(
+    def __init__(
+        self,
+        models: List[TorchModelType],
+        device_map: Dict[int, Union[int, str]] | None = None,
+    ):
         """
         Initializes the EnsembleModule with a list of models.
 
@@ -73,6 +84,16 @@ class EnsembleModule(nn.Module):
         super().__init__()
         # TODO: distribute models to devices
        self.model_list = nn.ModuleList(models)
+        self.device_map = device_map
+        if self.device_map is not None:
+            self._move_models_to_devices()
+
+    def _move_models_to_devices(self):
+        for model_idx, device_id in self.device_map.items():
+            log.info(f"Moving model {model_idx} to device {device_id}")
+            self.model_list[model_idx] = self.model_list[model_idx].to(
+                device_id, non_blocking=True
+            )
 
     def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
         """
@@ -86,6 +107,49 @@ class EnsembleModule(nn.Module):
         """
         return torch.stack(outputs).mean(dim=0)
 
+    def _parallel_forward_with_device_map(self, *args: Any, **kwargs: Any) -> List[Any]:
+        """
+        Performs parallel forward pass using device mapping with futures.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[Any]: List of outputs from all models, all moved to the same device.
+        """
+        futures = []
+
+        device_data_cache = {}
+        for i, model in enumerate(self.model_list):
+            device_id = self.device_map.get(i, "cpu")
+
+            if device_id not in device_data_cache:
+                # Move inputs to the same device as the model
+                device_args = to_device(
+                    args, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_kwargs = to_device(
+                    kwargs, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_data_cache[device_id] = (device_args, device_kwargs)
+            else:
+                device_args, device_kwargs = device_data_cache[device_id]
+
+            # Create a future for asynchronous execution
+            future = torch.jit.fork(model, *device_args, **device_kwargs)
+            futures.append(future)
+
+        # Wait for all futures to complete and collect results
+        outputs = [torch.jit.wait(future) for future in futures]
+
+        # Move all outputs to the same device (use the device of the first model or cpu as fallback)
+        target_device = self.device_map.get(0, "cpu") if self.device_map else "cpu"
+        outputs = [
+            to_device(output, target_device, non_blocking=True) for output in outputs
+        ]
+        return outputs
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Performs a forward pass by averaging the outputs of the models.
@@ -97,20 +161,25 @@ class EnsembleModule(nn.Module):
         Returns:
             Aggregated output from the ensemble of models.
         """
-
+        if self.device_map is None:
+            outputs = [model(*args, **kwargs) for model in self.model_list]
+        else:
+            # Parallel execution with device mapping
+            outputs = self._parallel_forward_with_device_map(*args, **kwargs)
         return aggregate_tensors(outputs, self._aggregate_tensors)
 
 
-class WeightedEnsembleModule(nn.Module):
+class WeightedEnsembleModule(nn.Module, Generic[TorchModelType]):
     """
     Ensemble module that computes a weighted average of the outputs from multiple models.
     """
 
     def __init__(
         self,
-        models: List[
+        models: List[TorchModelType],
         weights: List[float] | Tensor | np.ndarray,
         normalize: bool = True,
+        device_map: Dict[int, Union[int, str]] | None = None,
     ):
         """
         Initializes the WeightedEnsembleModule with models and their corresponding weights.
@@ -119,9 +188,12 @@ class WeightedEnsembleModule(nn.Module):
             models (List[nn.Module]): List of models to ensemble.
             weights (List[float] | Tensor | np.ndarray): Weights for each model.
             normalize (bool, optional): If True, normalizes the weights. Defaults to True.
+            device_map (Dict[int, Union[int, str]] | None, optional): Device mapping for parallel execution. Defaults to None.
         """
         super().__init__()
         self.model_list = nn.ModuleList(models)
+        self.device_map = device_map
+
         if isinstance(weights, (list, tuple, ListConfig)):
             weights = torch.tensor(weights)
         elif isinstance(weights, Tensor):
@@ -139,6 +211,17 @@ class WeightedEnsembleModule(nn.Module):
             weights = weights / weights.sum()
         self.register_buffer("weights", weights)
 
+        if self.device_map is not None:
+            self._move_models_to_devices()
+
+    def _move_models_to_devices(self):
+        """Move models to their assigned devices according to device_map."""
+        for model_idx, device_id in self.device_map.items():
+            log.info(f"Moving model {model_idx} to device {device_id}")
+            self.model_list[model_idx] = self.model_list[model_idx].to(
+                device_id, non_blocking=True
+            )
+
     def _aggregate_tensors(self, outputs: List[Tensor]) -> Tensor:
         """
         Aggregates a list of tensors using the provided weights.
@@ -152,6 +235,48 @@ class WeightedEnsembleModule(nn.Module):
         weights = cast(Tensor, self.weights).view(-1, *([1] * outputs[0].dim()))
         return (torch.stack(outputs) * weights).sum(dim=0)
 
+    def _parallel_forward_with_device_map(self, *args: Any, **kwargs: Any) -> List[Any]:
+        """
+        Performs parallel forward pass using device mapping with futures.
+
+        Args:
+            *args: Variable length argument list.
+            **kwargs: Arbitrary keyword arguments.
+
+        Returns:
+            List[Any]: List of outputs from all models, all moved to the same device.
+        """
+        futures = []
+
+        device_data_cache = {}
+        for i, model in enumerate(self.model_list):
+            device_id = self.device_map.get(i, "cpu")
+
+            if device_id not in device_data_cache:
+                # Move inputs to the same device as the model
+                device_args = to_device(
+                    args, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_kwargs = to_device(
+                    kwargs, device_id, copy_on_move=True, non_blocking=True
+                )
+                device_data_cache[device_id] = (device_args, device_kwargs)
+            else:
+                device_args, device_kwargs = device_data_cache[device_id]
+
+            # Create a future for asynchronous execution
+            future = torch.jit.fork(model, *device_args, **device_kwargs)
+            futures.append(future)
+
+        # Wait for all futures to complete and collect results
+        outputs = [torch.jit.wait(future) for future in futures]
+
+        # Move all outputs to the same device (use the device of the first model or cpu as fallback)
+        target_device = self.device_map.get(0, "cpu") if self.device_map else "cpu"
+        outputs = [to_device(output, target_device) for output in outputs]
+
+        return outputs
+
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         """
         Performs a forward pass by computing the weighted average of the models' outputs.
@@ -163,7 +288,11 @@ class WeightedEnsembleModule(nn.Module):
         Returns:
             Weighted aggregated output from the ensemble of models.
         """
-
+        if self.device_map is None:
+            outputs = [model(*args, **kwargs) for model in self.model_list]
+        else:
+            # Parallel execution with device mapping
+            outputs = self._parallel_forward_with_device_map(*args, **kwargs)
         return aggregate_tensors(outputs, self._aggregate_tensors)
 
 
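A minimal usage sketch of the new `device_map` argument (illustrative only; assumes two CUDA devices are visible):

import torch
from torch import nn
from fusion_bench.models.wrappers.ensemble import EnsembleModule

# Each member is moved to its mapped device; forward() then dispatches the
# members with torch.jit.fork and gathers the averaged output on the device
# mapped to model index 0 ("cuda:0" here).
members = [nn.Linear(16, 4), nn.Linear(16, 4)]
ensemble = EnsembleModule(members, device_map={0: "cuda:0", 1: "cuda:1"})
out = ensemble(torch.randn(8, 16))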
fusion_bench/optim/__init__.py
CHANGED
@@ -1,2 +1,40 @@
-
-from
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+from . import lr_scheduler
+
+_extra_objects = {
+    "lr_scheduler": lr_scheduler,
+}
+_import_structure = {
+    "exception": [
+        "NoClosureError",
+        "NoSparseGradientError",
+        "NegativeLRError",
+        "NegativeStepError",
+        "ZeroParameterSizeError",
+    ],
+    "mezo": ["MeZO"],
+    "muon": ["Muon"],
+}
+
+if TYPE_CHECKING:
+    from .exception import (
+        NegativeLRError,
+        NegativeStepError,
+        NoClosureError,
+        NoSparseGradientError,
+        ZeroParameterSizeError,
+    )
+    from .mezo import MeZO
+    from .muon import Muon
+
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        extra_objects=_extra_objects,
+    )
fusion_bench/optim/lr_scheduler/__init__.py
CHANGED
@@ -1 +1,27 @@
-
+import sys
+from typing import TYPE_CHECKING
+
+from fusion_bench.utils.lazy_imports import LazyImporter
+
+_import_structure = {
+    "linear_warmup": [
+        "BaseLinearWarmupScheduler",
+        "LinearWarmupScheduler",
+        "CosineDecayWithWarmup",
+        "PolySchedulerWithWarmup",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .linear_warmup import (
+        BaseLinearWarmupScheduler,
+        CosineDecayWithWarmup,
+        LinearWarmupScheduler,
+        PolySchedulerWithWarmup,
+    )
+else:
+    sys.modules[__name__] = LazyImporter(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+    )