fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fusion_bench/__init__.py +152 -42
- fusion_bench/dataset/__init__.py +27 -4
- fusion_bench/dataset/clip_dataset.py +2 -2
- fusion_bench/method/__init__.py +18 -1
- fusion_bench/method/classification/__init__.py +27 -2
- fusion_bench/method/classification/image_classification_finetune.py +214 -0
- fusion_bench/method/ensemble.py +17 -2
- fusion_bench/method/linear/__init__.py +6 -2
- fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
- fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
- fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
- fusion_bench/method/opcm/opcm.py +1 -0
- fusion_bench/method/pwe_moe/module.py +0 -2
- fusion_bench/method/simple_average.py +2 -2
- fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
- fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
- fusion_bench/method/ties_merging/ties_merging.py +22 -6
- fusion_bench/method/wudi/__init__.py +1 -0
- fusion_bench/method/wudi/wudi.py +105 -0
- fusion_bench/mixins/__init__.py +2 -0
- fusion_bench/mixins/lightning_fabric.py +4 -0
- fusion_bench/mixins/pyinstrument.py +174 -0
- fusion_bench/mixins/serialization.py +25 -78
- fusion_bench/mixins/simple_profiler.py +106 -23
- fusion_bench/modelpool/__init__.py +2 -0
- fusion_bench/modelpool/base_pool.py +77 -14
- fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
- fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
- fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
- fusion_bench/models/__init__.py +35 -9
- fusion_bench/models/hf_clip.py +4 -0
- fusion_bench/models/hf_utils.py +2 -1
- fusion_bench/models/model_card_templates/default.md +8 -1
- fusion_bench/models/wrappers/ensemble.py +136 -7
- fusion_bench/optim/__init__.py +40 -2
- fusion_bench/optim/lr_scheduler/__init__.py +27 -1
- fusion_bench/optim/muon.py +339 -0
- fusion_bench/programs/__init__.py +2 -0
- fusion_bench/programs/fabric_fusion_program.py +2 -2
- fusion_bench/programs/fusion_program.py +271 -0
- fusion_bench/scripts/cli.py +2 -2
- fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
- fusion_bench/tasks/clip_classification/__init__.py +15 -0
- fusion_bench/utils/__init__.py +167 -21
- fusion_bench/utils/devices.py +30 -8
- fusion_bench/utils/lazy_imports.py +91 -12
- fusion_bench/utils/lazy_state_dict.py +58 -5
- fusion_bench/utils/misc.py +104 -13
- fusion_bench/utils/packages.py +4 -0
- fusion_bench/utils/path.py +7 -0
- fusion_bench/utils/pylogger.py +6 -0
- fusion_bench/utils/rich_utils.py +8 -3
- fusion_bench/utils/state_dict_arithmetic.py +935 -162
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
- fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
- fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
- fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
- fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
- fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
- fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
- fusion_bench_config/method/wudi/wudi.yaml +4 -0
- fusion_bench_config/model_fusion.yaml +45 -0
- fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
- fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
- fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
- fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
- fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
fusion_bench/optim/muon.py
ADDED
@@ -0,0 +1,339 @@
+# Copied from https://github.com/KellerJordan/Muon
+import torch
+import torch.distributed as dist
+
+
+def zeropower_via_newtonschulz5(G, steps: int):
+    """
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert (
+        G.ndim >= 2
+    )  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+    # Perform the NS iterations
+    for _ in range(steps):
+        A = X @ X.mT
+        B = (
+            b * A + c * A @ A
+        )  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+
+
+def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
+    momentum.lerp_(grad, 1 - beta)
+    update = grad.lerp_(momentum, beta) if nesterov else momentum
+    if update.ndim == 4:  # for the case of conv filters
+        update = update.view(len(update), -1)
+    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
+    update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
+    return update
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    https://kellerjordan.github.io/posts/muon/
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
+    advantage that it can be stably run in bfloat16 on the GPU.
+
+    Muon should only be used for hidden weight layers. The input embedding, final output layer,
+    and any internal gains or biases should be optimized using a standard method such as AdamW.
+    Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
+    collapsing their last 3 dimensions.
+
+    Arguments:
+        lr: The learning rate, in units of spectral norm per update.
+        weight_decay: The AdamW-style weight decay.
+        momentum: The momentum. A value of 0.95 here is usually fine.
+    """
+
+    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        assert (
+            isinstance(params, list)
+            and len(params) >= 1
+            and isinstance(params[0], torch.nn.Parameter)
+        )
+        params = sorted(params, key=lambda x: x.size(), reverse=True)
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params = group["params"]
+            params_pad = params + [torch.empty_like(params[-1])] * (
+                dist.get_world_size() - len(params) % dist.get_world_size()
+            )
+            for base_i in range(len(params))[:: dist.get_world_size()]:
+                if base_i + dist.get_rank() < len(params):
+                    p = params[base_i + dist.get_rank()]
+                    if p.grad is None:
+                        # continue
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["momentum_buffer"] = torch.zeros_like(p)
+                    update = muon_update(
+                        p.grad, state["momentum_buffer"], beta=group["momentum"]
+                    )
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
+                dist.all_gather(
+                    params_pad[base_i : base_i + dist.get_world_size()],
+                    params_pad[base_i + dist.get_rank()],
+                )
+
+        return loss
+
+
+class SingleDeviceMuon(torch.optim.Optimizer):
+    """
+    Muon variant for usage in non-distributed settings.
+    """
+
+    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    # continue
+                    p.grad = torch.zeros_like(p)  # Force synchronization
+                state = self.state[p]
+                if len(state) == 0:
+                    state["momentum_buffer"] = torch.zeros_like(p)
+                update = muon_update(
+                    p.grad, state["momentum_buffer"], beta=group["momentum"]
+                )
+                p.mul_(1 - group["lr"] * group["weight_decay"])
+                p.add_(update.reshape(p.shape), alpha=-group["lr"])
+
+        return loss
+
+
+def adam_update(grad, buf1, buf2, step, betas, eps):
+    buf1.lerp_(grad, 1 - betas[0])
+    buf2.lerp_(grad.square(), 1 - betas[1])
+    buf1c = buf1 / (1 - betas[0] ** step)
+    buf2c = buf2 / (1 - betas[1] ** step)
+    return buf1c / (buf2c.sqrt() + eps)
+
+
+class MuonWithAuxAdam(torch.optim.Optimizer):
+    """
+    Distributed Muon variant that can be used for all parameters in the network, since it runs an
+    internal AdamW for the parameters that are not compatible with Muon. The user must manually
+    specify which parameters shall be optimized with Muon and which with Adam by passing in a
+    list of param_groups with the `use_muon` flag set.
+
+    The point of this class is to allow the user to have a single optimizer in their code, rather
+    than having both a Muon and an Adam which each need to be stepped.
+
+    You can see an example usage below:
+
+    https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
+    ```
+    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+    scalar_params = [p for p in model.parameters() if p.ndim < 2]
+    head_params = [model.lm_head.weight]
+
+    from muon import MuonWithAuxAdam
+    adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
+    adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
+    muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
+    param_groups = [*adam_groups, muon_group]
+    optimizer = MuonWithAuxAdam(param_groups)
+    ```
+    """
+
+    def __init__(self, param_groups):
+        for group in param_groups:
+            assert "use_muon" in group
+            if group["use_muon"]:
+                group["params"] = sorted(
+                    group["params"], key=lambda x: x.size(), reverse=True
+                )
+                # defaults
+                group["lr"] = group.get("lr", 0.02)
+                group["momentum"] = group.get("momentum", 0.95)
+                group["weight_decay"] = group.get("weight_decay", 0)
+                assert set(group.keys()) == set(
+                    ["params", "lr", "momentum", "weight_decay", "use_muon"]
+                )
+            else:
+                # defaults
+                group["lr"] = group.get("lr", 3e-4)
+                group["betas"] = group.get("betas", (0.9, 0.95))
+                group["eps"] = group.get("eps", 1e-10)
+                group["weight_decay"] = group.get("weight_decay", 0)
+                assert set(group.keys()) == set(
+                    ["params", "lr", "betas", "eps", "weight_decay", "use_muon"]
+                )
+        super().__init__(param_groups, dict())
+
+    @torch.no_grad()
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if group["use_muon"]:
+                params = group["params"]
+                params_pad = params + [torch.empty_like(params[-1])] * (
+                    dist.get_world_size() - len(params) % dist.get_world_size()
+                )
+                for base_i in range(len(params))[:: dist.get_world_size()]:
+                    if base_i + dist.get_rank() < len(params):
+                        p = params[base_i + dist.get_rank()]
+                        if p.grad is None:
+                            # continue
+                            p.grad = torch.zeros_like(p)  # Force synchronization
+                        state = self.state[p]
+                        if len(state) == 0:
+                            state["momentum_buffer"] = torch.zeros_like(p)
+                        update = muon_update(
+                            p.grad, state["momentum_buffer"], beta=group["momentum"]
+                        )
+                        p.mul_(1 - group["lr"] * group["weight_decay"])
+                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
+                    dist.all_gather(
+                        params_pad[base_i : base_i + dist.get_world_size()],
+                        params_pad[base_i + dist.get_rank()],
+                    )
+            else:
+                for p in group["params"]:
+                    if p.grad is None:
+                        # continue
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["exp_avg"] = torch.zeros_like(p)
+                        state["exp_avg_sq"] = torch.zeros_like(p)
+                        state["step"] = 0
+                    state["step"] += 1
+                    update = adam_update(
+                        p.grad,
+                        state["exp_avg"],
+                        state["exp_avg_sq"],
+                        state["step"],
+                        group["betas"],
+                        group["eps"],
+                    )
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+                    p.add_(update, alpha=-group["lr"])
+
+        return loss
+
+
+class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
+    """
+    Non-distributed variant of MuonWithAuxAdam.
+    """
+
+    def __init__(self, param_groups):
+        for group in param_groups:
+            assert "use_muon" in group
+            if group["use_muon"]:
+                # defaults
+                group["lr"] = group.get("lr", 0.02)
+                group["momentum"] = group.get("momentum", 0.95)
+                group["weight_decay"] = group.get("weight_decay", 0)
+                assert set(group.keys()) == set(
+                    ["params", "lr", "momentum", "weight_decay", "use_muon"]
+                )
+            else:
+                # defaults
+                group["lr"] = group.get("lr", 3e-4)
+                group["betas"] = group.get("betas", (0.9, 0.95))
+                group["eps"] = group.get("eps", 1e-10)
+                group["weight_decay"] = group.get("weight_decay", 0)
+                assert set(group.keys()) == set(
+                    ["params", "lr", "betas", "eps", "weight_decay", "use_muon"]
+                )
+        super().__init__(param_groups, dict())
+
+    @torch.no_grad()
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if group["use_muon"]:
+                for p in group["params"]:
+                    if p.grad is None:
+                        # continue
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["momentum_buffer"] = torch.zeros_like(p)
+                    update = muon_update(
+                        p.grad, state["momentum_buffer"], beta=group["momentum"]
+                    )
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
+            else:
+                for p in group["params"]:
+                    if p.grad is None:
+                        # continue
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["exp_avg"] = torch.zeros_like(p)
+                        state["exp_avg_sq"] = torch.zeros_like(p)
+                        state["step"] = 0
+                    state["step"] += 1
+                    update = adam_update(
+                        p.grad,
+                        state["exp_avg"],
+                        state["exp_avg_sq"],
+                        state["step"],
+                        group["betas"],
+                        group["eps"],
+                    )
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+                    p.add_(update, alpha=-group["lr"])
+
+        return loss
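The docstrings above describe the intended split: Muon for hidden 2D weight matrices, an auxiliary AdamW for embeddings, output heads, and scalar parameters. The following is a minimal single-device sketch of that split using the classes added in this file; the import path `fusion_bench.optim.muon` follows from the file location, and the toy model, parameter filters, and hyperparameters are illustrative assumptions rather than part of the package.

```python
# Minimal sketch (not from the package): route hidden 2D weights to Muon and
# everything else to the auxiliary AdamW, as the docstrings above recommend.
import torch
from torch import nn

from fusion_bench.optim.muon import (  # module path assumed from the file location
    SingleDeviceMuonWithAuxAdam,
    zeropower_via_newtonschulz5,
)

# Toy model: an embedding (Adam), two hidden linear weights (Muon), and biases (Adam).
model = nn.Sequential(nn.Embedding(100, 32), nn.Linear(32, 32), nn.Linear(32, 10))

muon_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and not n.startswith("0.")]
adam_params = [p for n, p in model.named_parameters() if p.ndim < 2 or n.startswith("0.")]
param_groups = [
    dict(params=muon_params, lr=0.02, use_muon=True),
    dict(params=adam_params, lr=3e-4, use_muon=False),
]
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

loss = model(torch.randint(0, 100, (8, 4))).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

# Sanity check of the Newton-Schulz orthogonalization: the result behaves like US'V^T
# with singular values of order one, so the diagonal below is roughly 1 but not exactly.
G = torch.randn(16, 32)
ortho = zeropower_via_newtonschulz5(G, steps=5)
print((ortho @ ortho.mT).diagonal())
```

The distributed `Muon` and `MuonWithAuxAdam` classes additionally require an initialized `torch.distributed` process group, since their `step()` pads the parameter list to a multiple of the world size and `all_gather`s the updated shards.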
fusion_bench/programs/__init__.py
CHANGED
@@ -7,11 +7,13 @@ from fusion_bench.utils.lazy_imports import LazyImporter
 _import_structure = {
     "base_program": ["BaseHydraProgram"],
     "fabric_fusion_program": ["FabricModelFusionProgram"],
+    "fusion_program": ["ModelFusionProgram"],
 }
 
 if TYPE_CHECKING:
     from .base_program import BaseHydraProgram
     from .fabric_fusion_program import FabricModelFusionProgram
+    from .fusion_program import ModelFusionProgram
 else:
     sys.modules[__name__] = LazyImporter(
         __name__,
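With the new entry registered in `_import_structure`, the class is resolved through the package namespace on first access rather than at import time. A minimal sketch of the effect, assuming the wheel is installed:

```python
import fusion_bench.programs as programs

# Accessing the attribute triggers the lazy import of
# `fusion_bench.programs.fusion_program` via the LazyImporter shown above.
ProgramCls = programs.ModelFusionProgram
```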
fusion_bench/programs/fabric_fusion_program.py
CHANGED
@@ -5,7 +5,7 @@ from pathlib import Path
 from typing import Any, Callable, Dict, Iterable, List, Optional, Union  # noqa: F401
 
 import lightning as L
-from
+from lightning_utilities.core.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
 from tqdm.auto import tqdm
@@ -236,7 +236,7 @@ class FabricModelFusionProgram(
 
         # create symbol link to hydra output directory
         if (
-
+            rank_zero_only.rank == 0
             and self.log_dir is not None
             and os.path.abspath(self.log_dir) != os.path.abspath(get_hydra_output_dir())
         ):
fusion_bench/programs/fusion_program.py
ADDED
@@ -0,0 +1,271 @@
+import json
+import os
+from typing import Any, Dict, Iterable, List, Optional, Union
+
+import lightning as L
+from lightning_utilities.core.rank_zero import rank_zero_only
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+from tqdm.auto import tqdm
+
+from fusion_bench import (
+    BaseAlgorithm,
+    BaseHydraProgram,
+    BaseModelPool,
+    BaseTaskPool,
+    RuntimeConstants,
+    auto_register_config,
+    get_rankzero_logger,
+    import_object,
+    instantiate,
+    timeit_context,
+)
+from fusion_bench.utils.json import print_json
+from fusion_bench.utils.rich_utils import print_bordered, print_config_tree
+
+log = get_rankzero_logger(__name__)
+
+
+@auto_register_config
+class ModelFusionProgram(BaseHydraProgram):
+    method: BaseAlgorithm
+    modelpool: BaseModelPool
+    taskpool: Optional[BaseTaskPool] = None
+
+    _config_mapping = BaseHydraProgram._config_mapping | {
+        "_method": "method",
+        "_modelpool": "modelpool",
+        "_taskpool": "taskpool",
+        "fast_dev_run": "fast_dev_run",
+        "seed": "seed",
+        "path": "path",
+    }
+
+    def __init__(
+        self,
+        method: DictConfig,
+        modelpool: DictConfig,
+        taskpool: Optional[DictConfig] = None,
+        *,
+        print_config: bool = True,
+        dry_run: bool = False,
+        report_save_path: Optional[str] = None,
+        merged_model_save_path: Optional[str] = None,
+        merged_model_save_kwargs: Optional[DictConfig] = None,
+        fast_dev_run: bool = False,
+        seed: Optional[int] = None,
+        print_function_call: bool = True,
+        path: DictConfig = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self._method = method
+        self._modelpool = modelpool
+        self._taskpool = taskpool
+        self.report_save_path = report_save_path
+        self.merged_model_save_path = merged_model_save_path
+        self.merged_model_save_kwargs = merged_model_save_kwargs
+        self.fast_dev_run = fast_dev_run
+        self.seed = seed
+        self.path = path
+        RuntimeConstants.debug = fast_dev_run
+        RuntimeConstants.print_function_call = print_function_call
+        if path is not None:
+            RuntimeConstants.cache_dir = path.get("cache_dir", None)
+
+        if print_config:
+            print_config_tree(
+                self.config,
+                print_order=["method", "modelpool", "taskpool"],
+            )
+        if dry_run:
+            log.info("The program is running in dry-run mode. Exiting.")
+            exit(0)
+
+    def _instantiate_and_setup(
+        self, config: DictConfig, compat_load_fn: Optional[str] = None
+    ):
+        R"""
+        Instantiates and sets up an object based on the provided configuration.
+
+        This method performs the following steps:
+        1. Checks if the configuration dictionary contains the key "_target_".
+        2. If "_target_" is not found (for v0.1.x), attempts to instantiate the object using a compatible load function if provided.
+            - Logs a warning if "_target_" is missing.
+            - If `compat_load_fn` is provided, imports the function and uses it to instantiate the object.
+            - If `compat_load_fn` is not provided, raises a ValueError.
+        3. If "_target_" is found (for v.0.2.0 and above), attempts to import and instantiate the object using the `instantiate` function.
+            - Ensures the target can be imported.
+            - Uses the `instantiate` function with `_recursive_` set based on the configuration.
+        4. Sets the `_program` attribute of the instantiated object to `self` if the object has this attribute.
+        5. Returns the instantiated and set up object.
+        """
+        if "_target_" not in config:
+            log.warning(
+                "No '_target_' key found in config. Attempting to instantiate the object in a compatible way."
+            )
+            if compat_load_fn is not None:
+                compat_load_fn = import_object(compat_load_fn)
+                if rank_zero_only.rank == 0:
+                    print_bordered(
+                        OmegaConf.to_yaml(config),
+                        title="instantiate compat object",
+                        style="magenta",
+                        code_style="yaml",
+                    )
+                obj = compat_load_fn(config)
+            else:
+                raise ValueError(
+                    "No load function provided. Please provide a load function to instantiate the object."
+                )
+        else:
+            # try to import the object from the target
+            # this checks if the target is valid and can be imported
+            import_object(config._target_)
+            obj = instantiate(
+                config,
+                _recursive_=config.get("_recursive_", False),
+            )
+        if hasattr(obj, "_program"):
+            obj._program = self
+        return obj
+
+    def save_merged_model(self, merged_model):
+        """
+        Saves the merged model to the specified path.
+        """
+        if self.merged_model_save_path is not None:
+            # path to save the merged model, use "{log_dir}" to refer to the logger directory
+            save_path: str = self.merged_model_save_path
+            if "{log_dir}" in save_path and self.log_dir is not None:
+                save_path = save_path.format(log_dir=self.log_dir)
+
+            if os.path.dirname(save_path):
+                os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+            # save the merged model
+            if self.merged_model_save_kwargs is not None:
+                merged_model_save_kwargs = self.merged_model_save_kwargs
+            else:
+                merged_model_save_kwargs = {}
+            with timeit_context(f"Saving the merged model to {save_path}"):
+                self.modelpool.save_model(
+                    merged_model,
+                    save_path,
+                    **merged_model_save_kwargs,
+                )
+        else:
+            print("No save path specified for the merged model. Skipping saving.")
+
+    def evaluate_merged_model(
+        self,
+        taskpool: BaseTaskPool,
+        merged_model: Union[nn.Module, Dict, Iterable],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Union[Dict, List, Any]:
+        """
+        Evaluates the merged model using the provided task pool.
+
+        Depending on the type of the merged model, this function handles the evaluation differently:
+        - If the merged model is an instance of `nn.Module`, it directly evaluates the model.
+        - If the merged model is a dictionary, it extracts the model from the dictionary and evaluates it.
+          The evaluation report is then updated with the remaining dictionary items.
+        - If the merged model is an iterable, it recursively evaluates each model in the iterable.
+        - Raises a `ValueError` if the merged model is of an invalid type.
+
+        Args:
+            taskpool: The task pool used for evaluating the merged model.
+            merged_model: The merged model to be evaluated. It can be an instance of `nn.Module`, a dictionary, or an iterable.
+            *args: Additional positional arguments to be passed to the `evaluate` method of the taskpool.
+            **kwargs: Additional keyword arguments to be passed to the `evaluate` method of the taskpool.
+
+        Returns:
+            The evaluation report. The type of the report depends on the type of the merged model:
+            - If the merged model is an instance of `nn.Module`, the report is a dictionary.
+            - If the merged model is a dictionary, the report is a dictionary updated with the remaining dictionary items.
+            - If the merged model is an iterable, the report is a list of evaluation reports.
+        """
+        if isinstance(merged_model, nn.Module):
+            report = taskpool.evaluate(merged_model, *args, **kwargs)
+            return report
+        elif isinstance(merged_model, Dict):
+            report = {}
+            for key, item in merged_model.items():
+                if isinstance(item, nn.Module):
+                    report[key] = taskpool.evaluate(item, *args, **kwargs)
+                elif key == "models":
+                    # for multi-model evaluation
+                    report[key] = self.evaluate_merged_model(
+                        taskpool, item, *args, **kwargs
+                    )
+                else:
+                    # metadata
+                    report[key] = item
+            return report
+        elif isinstance(merged_model, Iterable):
+            return [
+                self.evaluate_merged_model(taskpool, m, *args, **kwargs)
+                for m in tqdm(merged_model, desc="Evaluating models")
+            ]
+        else:
+            raise ValueError(f"Invalid type for merged model: {type(merged_model)}")
+
+    def run(self):
+        """
+        Executes the model fusion program.
+        """
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+
+        log.info("Running the model fusion program.")
+        # setup the modelpool, method, and taskpool
+        log.info("loading model pool")
+        self.modelpool = self._instantiate_and_setup(
+            self._modelpool,
+            compat_load_fn="fusion_bench.compat.modelpool.load_modelpool_from_config",
+        )
+        log.info("loading method")
+        self.method = self._instantiate_and_setup(
+            self._method,
+            compat_load_fn="fusion_bench.compat.method.load_algorithm_from_config",
+        )
+        if self._taskpool is not None:
+            log.info("loading task pool")
+            self.taskpool = self._instantiate_and_setup(
+                self._taskpool,
+                compat_load_fn="fusion_bench.compat.taskpool.load_taskpool_from_config",
+            )
+
+        self.method.on_run_start()
+        merged_model = self.method.run(self.modelpool)
+        self.method.on_run_end()
+
+        if merged_model is None:
+            log.info(
+                "No merged model returned by the method. Skipping saving and evaluation."
+            )
+        else:
+            self.save_merged_model(merged_model)
+            if self.taskpool is not None:
+                report = self.evaluate_merged_model(self.taskpool, merged_model)
+                try:
+                    if rank_zero_only.rank == 0:
+                        print_json(report, print_type=False)
+                except Exception as e:
+                    log.warning(f"Failed to pretty print the report: {e}")
+                    log.info(report)
+                if self.report_save_path is not None:
+                    # save report (Dict) to a file
+                    # if the directory of `save_report` does not exists, create it
+                    if (
+                        "{log_dir}" in self.report_save_path
+                        and self.path.log_dir is not None
+                    ):
+                        self.report_save_path = self.report_save_path.format(
+                            log_dir=self.path.log_dir
+                        )
+                    os.makedirs(os.path.dirname(self.report_save_path), exist_ok=True)
+                    json.dump(report, open(self.report_save_path, "w"))
+            else:
+                log.info("No task pool specified. Skipping evaluation.")
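The `run()` method above drives the whole pipeline: instantiate the model pool, the method, and the optional task pool, run the method, then save and evaluate the merged model. The sketch below drives `ModelFusionProgram` programmatically under those semantics; the `_target_` values are placeholders standing in for the entries that normally come from the YAML files under `fusion_bench_config/` (for example the new `model_fusion.yaml`), and the exact class paths should be taken from the installed package.

```python
# Minimal sketch with hypothetical config values; real runs supply the configs via Hydra.
from omegaconf import DictConfig

from fusion_bench.programs import ModelFusionProgram

program = ModelFusionProgram(
    # Placeholder targets; substitute the `_target_` entries from fusion_bench_config/.
    method=DictConfig({"_target_": "fusion_bench.method.SimpleAverageAlgorithm"}),
    modelpool=DictConfig({"_target_": "fusion_bench.modelpool.BaseModelPool", "models": {}}),
    taskpool=None,                 # no evaluation step
    merged_model_save_path=None,   # skip saving the merged model
    print_config=False,
)
program.run()  # loads the pools and method, runs fusion, then saves/evaluates when configured
```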
fusion_bench/scripts/cli.py
CHANGED
@@ -20,8 +20,8 @@ log = logging.getLogger(__name__)
 
 
 def _get_default_config_path():
-    for
-    for
+    for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+        for config_dir in ["config", "fusion_bench_config"]:
             config_path = os.path.join(config_path_root, config_dir)
             if os.path.exists(config_path) and os.path.isdir(config_path):
                 return os.path.abspath(config_path)