fusion-bench 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +18 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  7. fusion_bench/method/ensemble.py +17 -2
  8. fusion_bench/method/linear/__init__.py +6 -2
  9. fusion_bench/method/linear/{simple_average_for_llama.py → simple_average_for_causallm.py} +8 -4
  10. fusion_bench/method/linear/{task_arithmetic_for_llama.py → task_arithmetic_for_causallm.py} +22 -12
  11. fusion_bench/method/linear/ties_merging_for_causallm.py +70 -0
  12. fusion_bench/method/opcm/opcm.py +1 -0
  13. fusion_bench/method/pwe_moe/module.py +0 -2
  14. fusion_bench/method/simple_average.py +2 -2
  15. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  16. fusion_bench/method/task_arithmetic/task_arithmetic.py +35 -10
  17. fusion_bench/method/ties_merging/ties_merging.py +22 -6
  18. fusion_bench/method/wudi/__init__.py +1 -0
  19. fusion_bench/method/wudi/wudi.py +105 -0
  20. fusion_bench/mixins/__init__.py +2 -0
  21. fusion_bench/mixins/lightning_fabric.py +4 -0
  22. fusion_bench/mixins/pyinstrument.py +174 -0
  23. fusion_bench/mixins/serialization.py +25 -78
  24. fusion_bench/mixins/simple_profiler.py +106 -23
  25. fusion_bench/modelpool/__init__.py +2 -0
  26. fusion_bench/modelpool/base_pool.py +77 -14
  27. fusion_bench/modelpool/causal_lm/causal_lm.py +32 -10
  28. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  29. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  30. fusion_bench/models/__init__.py +35 -9
  31. fusion_bench/models/hf_clip.py +4 -0
  32. fusion_bench/models/hf_utils.py +2 -1
  33. fusion_bench/models/model_card_templates/default.md +8 -1
  34. fusion_bench/models/wrappers/ensemble.py +136 -7
  35. fusion_bench/optim/__init__.py +40 -2
  36. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  37. fusion_bench/optim/muon.py +339 -0
  38. fusion_bench/programs/__init__.py +2 -0
  39. fusion_bench/programs/fabric_fusion_program.py +2 -2
  40. fusion_bench/programs/fusion_program.py +271 -0
  41. fusion_bench/scripts/cli.py +2 -2
  42. fusion_bench/taskpool/clip_vision/taskpool.py +11 -4
  43. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  44. fusion_bench/utils/__init__.py +167 -21
  45. fusion_bench/utils/devices.py +30 -8
  46. fusion_bench/utils/lazy_imports.py +91 -12
  47. fusion_bench/utils/lazy_state_dict.py +58 -5
  48. fusion_bench/utils/misc.py +104 -13
  49. fusion_bench/utils/packages.py +4 -0
  50. fusion_bench/utils/path.py +7 -0
  51. fusion_bench/utils/pylogger.py +6 -0
  52. fusion_bench/utils/rich_utils.py +8 -3
  53. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  54. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/METADATA +10 -3
  55. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/RECORD +76 -55
  56. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  57. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  58. fusion_bench_config/method/ensemble/simple_ensemble.yaml +1 -0
  59. fusion_bench_config/method/linear/{simple_average_for_llama.yaml → simple_average_for_causallm.yaml} +1 -1
  60. fusion_bench_config/method/linear/task_arithmetic_for_causallm.yaml +4 -0
  61. fusion_bench_config/method/linear/ties_merging_for_causallm.yaml +13 -0
  62. fusion_bench_config/method/wudi/wudi.yaml +4 -0
  63. fusion_bench_config/model_fusion.yaml +45 -0
  64. fusion_bench_config/modelpool/CausalLMPool/{Qwen2.5-1.5B_math_and_coder.yaml → Qwen2.5-1.5B_math_and_code.yaml} +1 -2
  65. fusion_bench_config/modelpool/CausalLMPool/Qwen2.5-1.5B_three_models.yaml +11 -0
  66. fusion_bench_config/modelpool/CausalLMPool/llama-7b_3-models_v1.yaml +11 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  72. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  73. fusion_bench_config/method/linear/task_arithmetic_for_llama.yaml +0 -4
  74. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/WHEEL +0 -0
  75. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/entry_points.txt +0 -0
  76. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/licenses/LICENSE +0 -0
  77. {fusion_bench-0.2.23.dist-info → fusion_bench-0.2.25.dist-info}/top_level.txt +0 -0
fusion_bench/optim/muon.py
@@ -0,0 +1,339 @@
+# Copied from https://github.com/KellerJordan/Muon
+import torch
+import torch.distributed as dist
+
+
+def zeropower_via_newtonschulz5(G, steps: int):
+    """
+    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+    zero even beyond the point where the iteration no longer converges all the way to one everywhere
+    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+    performance at all relative to UV^T, where USV^T = G is the SVD.
+    """
+    assert (
+        G.ndim >= 2
+    )  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
+    a, b, c = (3.4445, -4.7750, 2.0315)
+    X = G.bfloat16()
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+
+    # Ensure spectral norm is at most 1
+    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
+    # Perform the NS iterations
+    for _ in range(steps):
+        A = X @ X.mT
+        B = (
+            b * A + c * A @ A
+        )  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
+        X = a * X + B @ X
+
+    if G.size(-2) > G.size(-1):
+        X = X.mT
+    return X
+
+
+def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
+    momentum.lerp_(grad, 1 - beta)
+    update = grad.lerp_(momentum, beta) if nesterov else momentum
+    if update.ndim == 4:  # for the case of conv filters
+        update = update.view(len(update), -1)
+    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
+    update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
+    return update
+
+
+class Muon(torch.optim.Optimizer):
+    """
+    Muon - MomentUm Orthogonalized by Newton-schulz
+
+    https://kellerjordan.github.io/posts/muon/
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
+    advantage that it can be stably run in bfloat16 on the GPU.
+
+    Muon should only be used for hidden weight layers. The input embedding, final output layer,
+    and any internal gains or biases should be optimized using a standard method such as AdamW.
+    Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
+    collapsing their last 3 dimensions.
+
+    Arguments:
+        lr: The learning rate, in units of spectral norm per update.
+        weight_decay: The AdamW-style weight decay.
+        momentum: The momentum. A value of 0.95 here is usually fine.
+    """
+
+    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        assert (
+            isinstance(params, list)
+            and len(params) >= 1
+            and isinstance(params[0], torch.nn.Parameter)
+        )
+        params = sorted(params, key=lambda x: x.size(), reverse=True)
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            params = group["params"]
+            params_pad = params + [torch.empty_like(params[-1])] * (
+                dist.get_world_size() - len(params) % dist.get_world_size()
+            )
+            for base_i in range(len(params))[:: dist.get_world_size()]:
+                if base_i + dist.get_rank() < len(params):
+                    p = params[base_i + dist.get_rank()]
+                    if p.grad is None:
+                        # continue
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["momentum_buffer"] = torch.zeros_like(p)
+                    update = muon_update(
+                        p.grad, state["momentum_buffer"], beta=group["momentum"]
+                    )
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
+                dist.all_gather(
+                    params_pad[base_i : base_i + dist.get_world_size()],
+                    params_pad[base_i + dist.get_rank()],
+                )
+
+        return loss
+
+
+class SingleDeviceMuon(torch.optim.Optimizer):
+    """
+    Muon variant for usage in non-distributed settings.
+    """
+
+    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
+        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    # continue
+                    p.grad = torch.zeros_like(p)  # Force synchronization
+                state = self.state[p]
+                if len(state) == 0:
+                    state["momentum_buffer"] = torch.zeros_like(p)
+                update = muon_update(
+                    p.grad, state["momentum_buffer"], beta=group["momentum"]
+                )
+                p.mul_(1 - group["lr"] * group["weight_decay"])
+                p.add_(update.reshape(p.shape), alpha=-group["lr"])
+
+        return loss
+
+
+def adam_update(grad, buf1, buf2, step, betas, eps):
+    buf1.lerp_(grad, 1 - betas[0])
+    buf2.lerp_(grad.square(), 1 - betas[1])
+    buf1c = buf1 / (1 - betas[0] ** step)
+    buf2c = buf2 / (1 - betas[1] ** step)
+    return buf1c / (buf2c.sqrt() + eps)
+
+
+class MuonWithAuxAdam(torch.optim.Optimizer):
+    """
+    Distributed Muon variant that can be used for all parameters in the network, since it runs an
+    internal AdamW for the parameters that are not compatible with Muon. The user must manually
+    specify which parameters shall be optimized with Muon and which with Adam by passing in a
+    list of param_groups with the `use_muon` flag set.
+
+    The point of this class is to allow the user to have a single optimizer in their code, rather
+    than having both a Muon and an Adam which each need to be stepped.
+
+    You can see an example usage below:
+
+    https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
+    ```
+    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
+    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
+    scalar_params = [p for p in model.parameters() if p.ndim < 2]
+    head_params = [model.lm_head.weight]
+
+    from muon import MuonWithAuxAdam
+    adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
+    adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
+    muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
+    param_groups = [*adam_groups, muon_group]
+    optimizer = MuonWithAuxAdam(param_groups)
+    ```
+    """
+
+    def __init__(self, param_groups):
+        for group in param_groups:
+            assert "use_muon" in group
+            if group["use_muon"]:
+                group["params"] = sorted(
+                    group["params"], key=lambda x: x.size(), reverse=True
+                )
+                # defaults
+                group["lr"] = group.get("lr", 0.02)
+                group["momentum"] = group.get("momentum", 0.95)
+                group["weight_decay"] = group.get("weight_decay", 0)
+                assert set(group.keys()) == set(
+                    ["params", "lr", "momentum", "weight_decay", "use_muon"]
+                )
+            else:
+                # defaults
+                group["lr"] = group.get("lr", 3e-4)
+                group["betas"] = group.get("betas", (0.9, 0.95))
+                group["eps"] = group.get("eps", 1e-10)
+                group["weight_decay"] = group.get("weight_decay", 0)
+                assert set(group.keys()) == set(
+                    ["params", "lr", "betas", "eps", "weight_decay", "use_muon"]
+                )
+        super().__init__(param_groups, dict())
+
+    @torch.no_grad()
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if group["use_muon"]:
+                params = group["params"]
+                params_pad = params + [torch.empty_like(params[-1])] * (
+                    dist.get_world_size() - len(params) % dist.get_world_size()
+                )
+                for base_i in range(len(params))[:: dist.get_world_size()]:
+                    if base_i + dist.get_rank() < len(params):
+                        p = params[base_i + dist.get_rank()]
+                        if p.grad is None:
+                            # continue
+                            p.grad = torch.zeros_like(p)  # Force synchronization
+                        state = self.state[p]
+                        if len(state) == 0:
+                            state["momentum_buffer"] = torch.zeros_like(p)
+                        update = muon_update(
+                            p.grad, state["momentum_buffer"], beta=group["momentum"]
+                        )
+                        p.mul_(1 - group["lr"] * group["weight_decay"])
+                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
+                    dist.all_gather(
+                        params_pad[base_i : base_i + dist.get_world_size()],
+                        params_pad[base_i + dist.get_rank()],
+                    )
+            else:
+                for p in group["params"]:
+                    if p.grad is None:
+                        # continue
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["exp_avg"] = torch.zeros_like(p)
+                        state["exp_avg_sq"] = torch.zeros_like(p)
+                        state["step"] = 0
+                    state["step"] += 1
+                    update = adam_update(
+                        p.grad,
+                        state["exp_avg"],
+                        state["exp_avg_sq"],
+                        state["step"],
+                        group["betas"],
+                        group["eps"],
+                    )
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+                    p.add_(update, alpha=-group["lr"])
+
+        return loss
+
+
+class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
+    """
+    Non-distributed variant of MuonWithAuxAdam.
+    """
+
+    def __init__(self, param_groups):
+        for group in param_groups:
+            assert "use_muon" in group
+            if group["use_muon"]:
+                # defaults
+                group["lr"] = group.get("lr", 0.02)
+                group["momentum"] = group.get("momentum", 0.95)
+                group["weight_decay"] = group.get("weight_decay", 0)
+                assert set(group.keys()) == set(
+                    ["params", "lr", "momentum", "weight_decay", "use_muon"]
+                )
+            else:
+                # defaults
+                group["lr"] = group.get("lr", 3e-4)
+                group["betas"] = group.get("betas", (0.9, 0.95))
+                group["eps"] = group.get("eps", 1e-10)
+                group["weight_decay"] = group.get("weight_decay", 0)
+                assert set(group.keys()) == set(
+                    ["params", "lr", "betas", "eps", "weight_decay", "use_muon"]
+                )
+        super().__init__(param_groups, dict())
+
+    @torch.no_grad()
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()
+
+        for group in self.param_groups:
+            if group["use_muon"]:
+                for p in group["params"]:
+                    if p.grad is None:
+                        # continue
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["momentum_buffer"] = torch.zeros_like(p)
+                    update = muon_update(
+                        p.grad, state["momentum_buffer"], beta=group["momentum"]
+                    )
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
+            else:
+                for p in group["params"]:
+                    if p.grad is None:
+                        # continue
+                        p.grad = torch.zeros_like(p)  # Force synchronization
+                    state = self.state[p]
+                    if len(state) == 0:
+                        state["exp_avg"] = torch.zeros_like(p)
+                        state["exp_avg_sq"] = torch.zeros_like(p)
+                        state["step"] = 0
+                    state["step"] += 1
+                    update = adam_update(
+                        p.grad,
+                        state["exp_avg"],
+                        state["exp_avg_sq"],
+                        state["step"],
+                        group["betas"],
+                        group["eps"],
+                    )
+                    p.mul_(1 - group["lr"] * group["weight_decay"])
+                    p.add_(update, alpha=-group["lr"])
+
+        return loss
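
The Newton-Schulz docstring above claims the iteration produces roughly US'V^T with S'_{ii} ~ Uniform(0.5, 1.5) rather than an exact orthogonalization. A minimal sketch to check that claim numerically, assuming fusion-bench 0.2.25 is installed so that `zeropower_via_newtonschulz5` is importable from `fusion_bench.optim.muon` (the matrix shape and step count are arbitrary choices for illustration):

```python
import torch

from fusion_bench.optim.muon import zeropower_via_newtonschulz5

# A random "gradient" matrix with a wide spread of singular values.
G = torch.randn(256, 512)
X = zeropower_via_newtonschulz5(G, steps=5).float()

# Per the docstring, the output's singular values should be squeezed into roughly
# (0.5, 1.5), i.e. X is approximately semi-orthogonal, while G's singular values
# span a much wider range.
print("G singular values:", torch.linalg.svdvals(G).min().item(), "to",
      torch.linalg.svdvals(G).max().item())
print("X singular values:", torch.linalg.svdvals(X).min().item(), "to",
      torch.linalg.svdvals(X).max().item())
```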
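
The Muon class docstring also prescribes the parameter split: 2D hidden weights go to Muon, while embeddings, output heads, and gains/biases should use an AdamW-style optimizer. A single-device usage sketch along those lines; the toy model and hyperparameters are illustrative only, and the required keys per group are exactly those asserted in `SingleDeviceMuonWithAuxAdam.__init__` above:

```python
import torch
from torch import nn

from fusion_bench.optim.muon import SingleDeviceMuonWithAuxAdam

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10))

# 2D weight matrices -> Muon; 1D parameters (biases here) -> the internal Adam.
matrix_params = [p for p in model.parameters() if p.ndim >= 2]
other_params = [p for p in model.parameters() if p.ndim < 2]

param_groups = [
    dict(params=matrix_params, lr=0.02, momentum=0.95, weight_decay=0, use_muon=True),
    dict(params=other_params, lr=3e-4, betas=(0.9, 0.95), eps=1e-10, weight_decay=0, use_muon=False),
]
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

# One illustrative step on random data.
x, y = torch.randn(8, 64), torch.randint(0, 10, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```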
fusion_bench/programs/__init__.py
@@ -7,11 +7,13 @@ from fusion_bench.utils.lazy_imports import LazyImporter
 _import_structure = {
     "base_program": ["BaseHydraProgram"],
     "fabric_fusion_program": ["FabricModelFusionProgram"],
+    "fusion_program": ["ModelFusionProgram"],
 }
 
 if TYPE_CHECKING:
     from .base_program import BaseHydraProgram
     from .fabric_fusion_program import FabricModelFusionProgram
+    from .fusion_program import ModelFusionProgram
 else:
     sys.modules[__name__] = LazyImporter(
         __name__,
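
For context, the `_import_structure` / `TYPE_CHECKING` pair above is the lazy-import pattern used throughout fusion_bench: symbols are only resolved on first access, while static type checkers still see the real imports, which is why both places must be updated together when a new program class is added. A generic, self-contained sketch of the mechanism for a hypothetical package `__init__.py` (this is a simplified stand-in using PEP 562 module `__getattr__`, not fusion_bench's actual `LazyImporter` from `fusion_bench/utils/lazy_imports.py`):

```python
import importlib
from typing import TYPE_CHECKING

# Maps submodule name -> public symbols it provides, mirroring `_import_structure` above.
_import_structure = {
    "fusion_program": ["ModelFusionProgram"],
}

if TYPE_CHECKING:  # static analyzers and IDEs resolve the real symbols
    from .fusion_program import ModelFusionProgram  # noqa: F401


def __getattr__(name):
    # Resolve the attribute lazily on first access, then cache it in globals().
    for submodule, symbols in _import_structure.items():
        if name in symbols:
            module = importlib.import_module(f".{submodule}", __name__)
            value = getattr(module, name)
            globals()[name] = value
            return value
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```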
fusion_bench/programs/fabric_fusion_program.py
@@ -5,7 +5,7 @@ from pathlib import Path
 from typing import Any, Callable, Dict, Iterable, List, Optional, Union  # noqa: F401
 
 import lightning as L
-from lightning.fabric.utilities.rank_zero import rank_zero_only
+from lightning_utilities.core.rank_zero import rank_zero_only
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
 from tqdm.auto import tqdm
@@ -236,7 +236,7 @@ class FabricModelFusionProgram(
 
         # create symbol link to hydra output directory
         if (
-            self.fabric.is_global_zero
+            rank_zero_only.rank == 0
            and self.log_dir is not None
            and os.path.abspath(self.log_dir) != os.path.abspath(get_hydra_output_dir())
        ):
fusion_bench/programs/fusion_program.py
@@ -0,0 +1,271 @@
+import json
+import os
+from typing import Any, Dict, Iterable, List, Optional, Union
+
+import lightning as L
+from lightning_utilities.core.rank_zero import rank_zero_only
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+from tqdm.auto import tqdm
+
+from fusion_bench import (
+    BaseAlgorithm,
+    BaseHydraProgram,
+    BaseModelPool,
+    BaseTaskPool,
+    RuntimeConstants,
+    auto_register_config,
+    get_rankzero_logger,
+    import_object,
+    instantiate,
+    timeit_context,
+)
+from fusion_bench.utils.json import print_json
+from fusion_bench.utils.rich_utils import print_bordered, print_config_tree
+
+log = get_rankzero_logger(__name__)
+
+
+@auto_register_config
+class ModelFusionProgram(BaseHydraProgram):
+    method: BaseAlgorithm
+    modelpool: BaseModelPool
+    taskpool: Optional[BaseTaskPool] = None
+
+    _config_mapping = BaseHydraProgram._config_mapping | {
+        "_method": "method",
+        "_modelpool": "modelpool",
+        "_taskpool": "taskpool",
+        "fast_dev_run": "fast_dev_run",
+        "seed": "seed",
+        "path": "path",
+    }
+
+    def __init__(
+        self,
+        method: DictConfig,
+        modelpool: DictConfig,
+        taskpool: Optional[DictConfig] = None,
+        *,
+        print_config: bool = True,
+        dry_run: bool = False,
+        report_save_path: Optional[str] = None,
+        merged_model_save_path: Optional[str] = None,
+        merged_model_save_kwargs: Optional[DictConfig] = None,
+        fast_dev_run: bool = False,
+        seed: Optional[int] = None,
+        print_function_call: bool = True,
+        path: DictConfig = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self._method = method
+        self._modelpool = modelpool
+        self._taskpool = taskpool
+        self.report_save_path = report_save_path
+        self.merged_model_save_path = merged_model_save_path
+        self.merged_model_save_kwargs = merged_model_save_kwargs
+        self.fast_dev_run = fast_dev_run
+        self.seed = seed
+        self.path = path
+        RuntimeConstants.debug = fast_dev_run
+        RuntimeConstants.print_function_call = print_function_call
+        if path is not None:
+            RuntimeConstants.cache_dir = path.get("cache_dir", None)
+
+        if print_config:
+            print_config_tree(
+                self.config,
+                print_order=["method", "modelpool", "taskpool"],
+            )
+        if dry_run:
+            log.info("The program is running in dry-run mode. Exiting.")
+            exit(0)
+
+    def _instantiate_and_setup(
+        self, config: DictConfig, compat_load_fn: Optional[str] = None
+    ):
+        R"""
+        Instantiates and sets up an object based on the provided configuration.
+
+        This method performs the following steps:
+        1. Checks if the configuration dictionary contains the key "_target_".
+        2. If "_target_" is not found (for v0.1.x), attempts to instantiate the object using a compatible load function if provided.
+           - Logs a warning if "_target_" is missing.
+           - If `compat_load_fn` is provided, imports the function and uses it to instantiate the object.
+           - If `compat_load_fn` is not provided, raises a ValueError.
+        3. If "_target_" is found (for v.0.2.0 and above), attempts to import and instantiate the object using the `instantiate` function.
+           - Ensures the target can be imported.
+           - Uses the `instantiate` function with `_recursive_` set based on the configuration.
+        4. Sets the `_program` attribute of the instantiated object to `self` if the object has this attribute.
+        5. Returns the instantiated and set up object.
+        """
+        if "_target_" not in config:
+            log.warning(
+                "No '_target_' key found in config. Attempting to instantiate the object in a compatible way."
+            )
+            if compat_load_fn is not None:
+                compat_load_fn = import_object(compat_load_fn)
+                if rank_zero_only.rank == 0:
+                    print_bordered(
+                        OmegaConf.to_yaml(config),
+                        title="instantiate compat object",
+                        style="magenta",
+                        code_style="yaml",
+                    )
+                obj = compat_load_fn(config)
+            else:
+                raise ValueError(
+                    "No load function provided. Please provide a load function to instantiate the object."
+                )
+        else:
+            # try to import the object from the target
+            # this checks if the target is valid and can be imported
+            import_object(config._target_)
+            obj = instantiate(
+                config,
+                _recursive_=config.get("_recursive_", False),
+            )
+        if hasattr(obj, "_program"):
+            obj._program = self
+        return obj
+
+    def save_merged_model(self, merged_model):
+        """
+        Saves the merged model to the specified path.
+        """
+        if self.merged_model_save_path is not None:
+            # path to save the merged model, use "{log_dir}" to refer to the logger directory
+            save_path: str = self.merged_model_save_path
+            if "{log_dir}" in save_path and self.log_dir is not None:
+                save_path = save_path.format(log_dir=self.log_dir)
+
+            if os.path.dirname(save_path):
+                os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+            # save the merged model
+            if self.merged_model_save_kwargs is not None:
+                merged_model_save_kwargs = self.merged_model_save_kwargs
+            else:
+                merged_model_save_kwargs = {}
+            with timeit_context(f"Saving the merged model to {save_path}"):
+                self.modelpool.save_model(
+                    merged_model,
+                    save_path,
+                    **merged_model_save_kwargs,
+                )
+        else:
+            print("No save path specified for the merged model. Skipping saving.")
+
+    def evaluate_merged_model(
+        self,
+        taskpool: BaseTaskPool,
+        merged_model: Union[nn.Module, Dict, Iterable],
+        *args: Any,
+        **kwargs: Any,
+    ) -> Union[Dict, List, Any]:
+        """
+        Evaluates the merged model using the provided task pool.
+
+        Depending on the type of the merged model, this function handles the evaluation differently:
+        - If the merged model is an instance of `nn.Module`, it directly evaluates the model.
+        - If the merged model is a dictionary, it extracts the model from the dictionary and evaluates it.
+          The evaluation report is then updated with the remaining dictionary items.
+        - If the merged model is an iterable, it recursively evaluates each model in the iterable.
+        - Raises a `ValueError` if the merged model is of an invalid type.
+
+        Args:
+            taskpool: The task pool used for evaluating the merged model.
+            merged_model: The merged model to be evaluated. It can be an instance of `nn.Module`, a dictionary, or an iterable.
+            *args: Additional positional arguments to be passed to the `evaluate` method of the taskpool.
+            **kwargs: Additional keyword arguments to be passed to the `evaluate` method of the taskpool.
+
+        Returns:
+            The evaluation report. The type of the report depends on the type of the merged model:
+            - If the merged model is an instance of `nn.Module`, the report is a dictionary.
+            - If the merged model is a dictionary, the report is a dictionary updated with the remaining dictionary items.
+            - If the merged model is an iterable, the report is a list of evaluation reports.
+        """
+        if isinstance(merged_model, nn.Module):
+            report = taskpool.evaluate(merged_model, *args, **kwargs)
+            return report
+        elif isinstance(merged_model, Dict):
+            report = {}
+            for key, item in merged_model.items():
+                if isinstance(item, nn.Module):
+                    report[key] = taskpool.evaluate(item, *args, **kwargs)
+                elif key == "models":
+                    # for multi-model evaluation
+                    report[key] = self.evaluate_merged_model(
+                        taskpool, item, *args, **kwargs
+                    )
+                else:
+                    # metadata
+                    report[key] = item
+            return report
+        elif isinstance(merged_model, Iterable):
+            return [
+                self.evaluate_merged_model(taskpool, m, *args, **kwargs)
+                for m in tqdm(merged_model, desc="Evaluating models")
+            ]
+        else:
+            raise ValueError(f"Invalid type for merged model: {type(merged_model)}")
+
+    def run(self):
+        """
+        Executes the model fusion program.
+        """
+        if self.seed is not None:
+            L.seed_everything(self.seed)
+
+        log.info("Running the model fusion program.")
+        # setup the modelpool, method, and taskpool
+        log.info("loading model pool")
+        self.modelpool = self._instantiate_and_setup(
+            self._modelpool,
+            compat_load_fn="fusion_bench.compat.modelpool.load_modelpool_from_config",
+        )
+        log.info("loading method")
+        self.method = self._instantiate_and_setup(
+            self._method,
+            compat_load_fn="fusion_bench.compat.method.load_algorithm_from_config",
+        )
+        if self._taskpool is not None:
+            log.info("loading task pool")
+            self.taskpool = self._instantiate_and_setup(
+                self._taskpool,
+                compat_load_fn="fusion_bench.compat.taskpool.load_taskpool_from_config",
+            )
+
+        self.method.on_run_start()
+        merged_model = self.method.run(self.modelpool)
+        self.method.on_run_end()
+
+        if merged_model is None:
+            log.info(
+                "No merged model returned by the method. Skipping saving and evaluation."
+            )
+        else:
+            self.save_merged_model(merged_model)
+            if self.taskpool is not None:
+                report = self.evaluate_merged_model(self.taskpool, merged_model)
+                try:
+                    if rank_zero_only.rank == 0:
+                        print_json(report, print_type=False)
+                except Exception as e:
+                    log.warning(f"Failed to pretty print the report: {e}")
+                    log.info(report)
+                if self.report_save_path is not None:
+                    # save report (Dict) to a file
+                    # if the directory of `save_report` does not exists, create it
+                    if (
+                        "{log_dir}" in self.report_save_path
+                        and self.path.log_dir is not None
+                    ):
+                        self.report_save_path = self.report_save_path.format(
+                            log_dir=self.path.log_dir
+                        )
+                    os.makedirs(os.path.dirname(self.report_save_path), exist_ok=True)
+                    json.dump(report, open(self.report_save_path, "w"))
+            else:
+                log.info("No task pool specified. Skipping evaluation.")
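
The `evaluate_merged_model` docstring above describes a dispatch over `nn.Module`, dictionaries, and iterables of models. A self-contained sketch of the resulting report shapes, using a stub task pool in place of a real fusion_bench task pool (the stub, its metric values, and the model choices are hypothetical; only the dispatch rules mirror the docstring):

```python
from torch import nn


class StubTaskPool:
    """Stand-in for a fusion_bench task pool; returns a fake per-model report."""

    def evaluate(self, model, *args, **kwargs):
        return {"accuracy": 0.0, "num_parameters": sum(p.numel() for p in model.parameters())}


def evaluate_merged_model(taskpool, merged_model):
    if isinstance(merged_model, nn.Module):
        # a single model -> a single report dict
        return taskpool.evaluate(merged_model)
    elif isinstance(merged_model, dict):
        report = {}
        for key, item in merged_model.items():
            if isinstance(item, nn.Module):
                report[key] = taskpool.evaluate(item)
            elif key == "models":  # nested multi-model evaluation
                report[key] = evaluate_merged_model(taskpool, item)
            else:  # metadata is passed through untouched
                report[key] = item
        return report
    else:
        # any other iterable of models -> a list of reports
        # (the real implementation raises ValueError for unsupported types)
        return [evaluate_merged_model(taskpool, m) for m in merged_model]


taskpool = StubTaskPool()
print(evaluate_merged_model(taskpool, nn.Linear(4, 2)))
print(evaluate_merged_model(
    taskpool,
    {"models": [nn.Linear(4, 2), nn.Linear(4, 2)], "scaling_factor": 0.5},
))
```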
fusion_bench/scripts/cli.py
@@ -20,8 +20,8 @@ log = logging.getLogger(__name__)
 
 
 def _get_default_config_path():
-    for config_dir in ["fusion_bench_config", "config"]:
-        for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+    for config_path_root in [os.getcwd(), PROJECT_ROOT_PATH]:
+        for config_dir in ["config", "fusion_bench_config"]:
             config_path = os.path.join(config_path_root, config_dir)
             if os.path.exists(config_path) and os.path.isdir(config_path):
                 return os.path.abspath(config_path)