fusion-bench 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75)
  1. fusion_bench/__init__.py +152 -42
  2. fusion_bench/dataset/__init__.py +27 -4
  3. fusion_bench/dataset/clip_dataset.py +2 -2
  4. fusion_bench/method/__init__.py +12 -1
  5. fusion_bench/method/classification/__init__.py +27 -2
  6. fusion_bench/method/classification/clip_finetune.py +6 -4
  7. fusion_bench/method/classification/image_classification_finetune.py +214 -0
  8. fusion_bench/method/dop/__init__.py +1 -0
  9. fusion_bench/method/dop/dop.py +366 -0
  10. fusion_bench/method/dop/min_norm_solvers.py +227 -0
  11. fusion_bench/method/dop/utils.py +73 -0
  12. fusion_bench/method/opcm/opcm.py +1 -0
  13. fusion_bench/method/pwe_moe/module.py +0 -2
  14. fusion_bench/method/tall_mask/task_arithmetic.py +2 -2
  15. fusion_bench/mixins/__init__.py +2 -0
  16. fusion_bench/mixins/pyinstrument.py +174 -0
  17. fusion_bench/mixins/simple_profiler.py +106 -23
  18. fusion_bench/modelpool/__init__.py +2 -0
  19. fusion_bench/modelpool/base_pool.py +77 -14
  20. fusion_bench/modelpool/clip_vision/modelpool.py +56 -19
  21. fusion_bench/modelpool/resnet_for_image_classification.py +208 -0
  22. fusion_bench/models/__init__.py +35 -9
  23. fusion_bench/optim/__init__.py +40 -2
  24. fusion_bench/optim/lr_scheduler/__init__.py +27 -1
  25. fusion_bench/optim/muon.py +339 -0
  26. fusion_bench/programs/__init__.py +2 -0
  27. fusion_bench/programs/fabric_fusion_program.py +2 -2
  28. fusion_bench/programs/fusion_program.py +271 -0
  29. fusion_bench/tasks/clip_classification/__init__.py +15 -0
  30. fusion_bench/utils/__init__.py +167 -21
  31. fusion_bench/utils/lazy_imports.py +91 -12
  32. fusion_bench/utils/lazy_state_dict.py +55 -5
  33. fusion_bench/utils/misc.py +104 -13
  34. fusion_bench/utils/packages.py +4 -0
  35. fusion_bench/utils/path.py +7 -0
  36. fusion_bench/utils/pylogger.py +6 -0
  37. fusion_bench/utils/rich_utils.py +1 -0
  38. fusion_bench/utils/state_dict_arithmetic.py +935 -162
  39. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/METADATA +8 -2
  40. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/RECORD +75 -56
  41. fusion_bench_config/method/bitdelta/bitdelta.yaml +3 -0
  42. fusion_bench_config/method/classification/image_classification_finetune.yaml +16 -0
  43. fusion_bench_config/method/classification/image_classification_finetune_test.yaml +6 -0
  44. fusion_bench_config/method/depth_upscaling.yaml +9 -0
  45. fusion_bench_config/method/dop/dop.yaml +30 -0
  46. fusion_bench_config/method/dummy.yaml +6 -0
  47. fusion_bench_config/method/ensemble/max_model_predictor.yaml +6 -0
  48. fusion_bench_config/method/ensemble/simple_ensemble.yaml +8 -1
  49. fusion_bench_config/method/ensemble/weighted_ensemble.yaml +8 -0
  50. fusion_bench_config/method/linear/linear_interpolation.yaml +8 -0
  51. fusion_bench_config/method/linear/weighted_average.yaml +3 -0
  52. fusion_bench_config/method/linear/weighted_average_for_llama.yaml +1 -1
  53. fusion_bench_config/method/model_recombination.yaml +8 -0
  54. fusion_bench_config/method/model_stock/model_stock.yaml +4 -1
  55. fusion_bench_config/method/opcm/opcm.yaml +5 -0
  56. fusion_bench_config/method/opcm/task_arithmetic.yaml +6 -0
  57. fusion_bench_config/method/opcm/ties_merging.yaml +5 -0
  58. fusion_bench_config/method/opcm/weight_average.yaml +5 -0
  59. fusion_bench_config/method/simple_average.yaml +9 -0
  60. fusion_bench_config/method/slerp/slerp.yaml +9 -0
  61. fusion_bench_config/method/slerp/slerp_lm.yaml +5 -0
  62. fusion_bench_config/method/smile_upscaling/smile_upscaling.yaml +3 -0
  63. fusion_bench_config/method/task_arithmetic.yaml +9 -0
  64. fusion_bench_config/method/ties_merging.yaml +3 -0
  65. fusion_bench_config/model_fusion.yaml +45 -0
  66. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar10.yaml +14 -0
  67. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet152_cifar100.yaml +14 -0
  68. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar10.yaml +14 -0
  69. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet18_cifar100.yaml +14 -0
  70. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar10.yaml +14 -0
  71. fusion_bench_config/modelpool/ResNetForImageClassfication/transformers/resnet50_cifar100.yaml +14 -0
  72. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/WHEEL +0 -0
  73. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/entry_points.txt +0 -0
  74. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/licenses/LICENSE +0 -0
  75. {fusion_bench-0.2.24.dist-info → fusion_bench-0.2.26.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,208 @@
1
+ from typing import (
2
+ TYPE_CHECKING,
3
+ Any,
4
+ Callable,
5
+ Dict,
6
+ Literal,
7
+ Optional,
8
+ TypeVar,
9
+ Union,
10
+ override,
11
+ )
12
+
13
+ import torch
14
+ from omegaconf import DictConfig
15
+ from torch import nn
16
+
17
+ from fusion_bench import BaseModelPool, auto_register_config, get_rankzero_logger
18
+ from fusion_bench.tasks.clip_classification import get_classnames, get_num_classes
19
+
20
+ if TYPE_CHECKING:
21
+ from torchvision.models import ResNet as TorchVisionResNet
22
+
23
+ log = get_rankzero_logger(__name__)
24
+
25
+
26
def load_torchvision_resnet(
    model_name: str, weights: Optional[str], num_classes: Optional[int]
) -> "TorchVisionResNet":
    """Build a torchvision ResNet by name, optionally replacing its head.

    Args:
        model_name: Attribute name of the constructor in ``torchvision.models``
            (e.g. ``"resnet18"``).
        weights: Weights specifier forwarded to the constructor, or ``None``
            for random initialization.
        num_classes: When given, the final fully-connected layer is replaced
            with a fresh ``nn.Linear`` producing this many logits.

    Returns:
        The constructed (and possibly re-headed) ResNet model.
    """
    import torchvision.models

    # Resolve the constructor dynamically so any torchvision ResNet works.
    builder = getattr(torchvision.models, model_name)
    resnet: "TorchVisionResNet" = builder(weights=weights)

    # Swap out the classification head when a class count is requested.
    if num_classes is not None:
        resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)

    return resnet
38
+
39
+
40
def load_transformers_resnet(
    config_path: str, pretrained: bool, dataset_name: Optional[str]
):
    """Load a HuggingFace ``ResNetForImageClassification`` model.

    Args:
        config_path: HF hub id or local path of the model/config.
        pretrained: If True, load pretrained weights from ``config_path``;
            otherwise build the model from the config alone.
        dataset_name: Optional dataset whose class names are used to rewire
            the label mappings and the classification head. When ``None``
            the model is returned untouched.
    """
    from transformers import AutoConfig, ResNetForImageClassification

    if pretrained:
        model = ResNetForImageClassification.from_pretrained(config_path)
    else:
        model = ResNetForImageClassification(AutoConfig.from_pretrained(config_path))

    if dataset_name is None:
        return model

    # Rewire the label <-> id mappings to the requested dataset's classes.
    classnames = get_classnames(dataset_name)
    model.config.id2label = dict(enumerate(classnames))
    model.config.label2id = {name: idx for idx, name in enumerate(classnames)}

    # Replace the classification head to match the new class count.
    # NOTE(review): presumably setting id2label above keeps
    # ``config.num_labels`` consistent with the new mapping — confirm.
    if model.config.num_labels > 0:
        head = nn.Linear(model.classifier[1].in_features, len(classnames))
    else:
        head = nn.Identity()
    model.classifier[1] = head
    return model
69
+
70
+
71
+ @auto_register_config
72
+ class ResNetForImageClassificationPool(BaseModelPool):
73
+ def __init__(self, type: str, **kwargs):
74
+ super().__init__(**kwargs)
75
+ assert type in ["torchvision", "transformers"]
76
+
77
+ def load_processor(
78
+ self, stage: Literal["train", "val", "test"] = "test", *args, **kwargs
79
+ ):
80
+ if self.type == "torchvision":
81
+ from torchvision import transforms
82
+
83
+ to_tensor = transforms.ToTensor()
84
+ normalize = transforms.Normalize(
85
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
86
+ )
87
+ if stage == "train":
88
+ train_transform = transforms.Compose(
89
+ [
90
+ transforms.RandomResizedCrop(224),
91
+ transforms.RandomHorizontalFlip(),
92
+ to_tensor,
93
+ normalize,
94
+ ]
95
+ )
96
+ return train_transform
97
+ else:
98
+ val_transform = transforms.Compose(
99
+ [
100
+ transforms.Resize(256),
101
+ transforms.CenterCrop(224),
102
+ to_tensor,
103
+ normalize,
104
+ ]
105
+ )
106
+ return val_transform
107
+
108
+ elif self.type == "transformers":
109
+ from transformers import AutoImageProcessor
110
+
111
+ if self.has_pretrained:
112
+ config_path = self._models["_pretrained_"].config_path
113
+ else:
114
+ for model_cfg in self._models.values():
115
+ if isinstance(model_cfg, str):
116
+ config_path = model_cfg
117
+ break
118
+ if "config_path" in model_cfg:
119
+ config_path = model_cfg["config_path"]
120
+ break
121
+ return AutoImageProcessor.from_pretrained(config_path)
122
+
123
+ @override
124
+ def load_model(self, model_name_or_config: Union[str, DictConfig], *args, **kwargs):
125
+ log.debug(f"Loading model: {model_name_or_config}", stacklevel=2)
126
+ if (
127
+ isinstance(model_name_or_config, str)
128
+ and model_name_or_config in self._models
129
+ ):
130
+ model_name_or_config = self._models[model_name_or_config]
131
+
132
+ if self.type == "torchvision":
133
+ from torchvision.models import (
134
+ resnet18,
135
+ resnet34,
136
+ resnet50,
137
+ resnet101,
138
+ resnet152,
139
+ )
140
+
141
+ match model_name_or_config:
142
+ case "resnet18":
143
+ model = resnet18()
144
+ case "resnet34":
145
+ model = resnet34()
146
+ case "resnet50":
147
+ model = resnet50()
148
+ case "resnet101":
149
+ model = resnet101()
150
+ case "resnet152":
151
+ model = resnet152()
152
+ case dict() | DictConfig() as model_config:
153
+ if "dataset_name" in model_config:
154
+ num_classes = get_num_classes(model_config["dataset_name"])
155
+ if "num_classes" in model_config:
156
+ assert (
157
+ num_classes == model_config["num_classes"]
158
+ ), f"num_classes mismatch: {num_classes} vs {model_config['num_classes']}"
159
+ elif "num_classes" in model_config:
160
+ num_classes = model_config["num_classes"]
161
+ else:
162
+ num_classes = None
163
+ model = load_torchvision_resnet(
164
+ model_name=model_config["model_name"],
165
+ weights=model_config.get("weights", None),
166
+ num_classes=num_classes,
167
+ )
168
+ case _:
169
+ raise ValueError(
170
+ f"Invalid model_name_or_config type: {type(model_name_or_config)}"
171
+ )
172
+ elif self.type == "transformers":
173
+ match model_name_or_config:
174
+ case str() as model_path:
175
+ from transformers import AutoModelForImageClassification
176
+
177
+ model = AutoModelForImageClassification.from_pretrained(model_path)
178
+ case dict() | DictConfig() as model_config:
179
+
180
+ model = load_transformers_resnet(
181
+ config_path=model_config["config_path"],
182
+ pretrained=model_config.get("pretrained", False),
183
+ dataset_name=model_config.get("dataset_name", None),
184
+ )
185
+ case _:
186
+ raise ValueError(
187
+ f"Invalid model_name_or_config type: {type(model_name_or_config)}"
188
+ )
189
+
190
+ # override forward to return logits only
191
+ original_forward = model.forward
192
+ model.forward = lambda pixel_values, **kwargs: original_forward(
193
+ pixel_values=pixel_values, **kwargs
194
+ ).logits
195
+ model.original_forward = original_forward
196
+ else:
197
+ raise ValueError(f"Unknown model type: {self.type}")
198
+ return model
199
+
200
+ @override
201
+ def save_model(self, model, path, *args, **kwargs):
202
+ if self.type == "torchvision":
203
+ torch.save(model.state_dict(), path)
204
+ elif self.type == "transformers":
205
+ model.save_pretrained(path)
206
+ self.load_processor().save_pretrained(path)
207
+ else:
208
+ raise ValueError(f"Unknown model type: {self.type}")
@@ -1,10 +1,36 @@
1
1
# flake8: noqa F401
import sys
from typing import TYPE_CHECKING

from fusion_bench.utils.lazy_imports import LazyImporter

from . import utils

# Objects imported eagerly and re-exported on the lazy module as-is.
_extra_objects = {
    "utils": utils,
}
# Maps submodule name -> public names it provides; each submodule is only
# imported when one of these attributes is first accessed.
_import_structure = {
    "hf_utils": [
        "create_default_model_card",
        "load_model_card_template",
        "save_pretrained_with_remote_code",
    ],
    "parameter_dict": ["ParameterDictModel"],
    "separate_io": ["separate_load", "separate_save"],
}

if TYPE_CHECKING:
    # Eager imports for static type checkers / IDEs only.
    from .hf_utils import (
        create_default_model_card,
        load_model_card_template,
        save_pretrained_with_remote_code,
    )
    from .parameter_dict import ParameterDictModel
    from .separate_io import separate_load, separate_save
else:
    # At runtime, replace this module with a lazy proxy that defers the
    # submodule imports declared above until first attribute access.
    sys.modules[__name__] = LazyImporter(
        __name__,
        globals()["__file__"],
        _import_structure,
        extra_objects=_extra_objects,
    )
@@ -1,2 +1,40 @@
1
import sys
from typing import TYPE_CHECKING

from fusion_bench.utils.lazy_imports import LazyImporter

from . import lr_scheduler

# Objects imported eagerly and re-exported on the lazy module as-is.
_extra_objects = {
    "lr_scheduler": lr_scheduler,
}
# Maps submodule name -> public names it provides; each submodule is only
# imported when one of these attributes is first accessed.
_import_structure = {
    "exception": [
        "NoClosureError",
        "NoSparseGradientError",
        "NegativeLRError",
        "NegativeStepError",
        "ZeroParameterSizeError",
    ],
    "mezo": ["MeZO"],
    "muon": ["Muon"],
}

if TYPE_CHECKING:
    # Eager imports for static type checkers / IDEs only.
    from .exception import (
        NegativeLRError,
        NegativeStepError,
        NoClosureError,
        NoSparseGradientError,
        ZeroParameterSizeError,
    )
    from .mezo import MeZO
    from .muon import Muon

else:
    # At runtime, replace this module with a lazy proxy that defers the
    # submodule imports declared above until first attribute access.
    sys.modules[__name__] = LazyImporter(
        __name__,
        globals()["__file__"],
        _import_structure,
        extra_objects=_extra_objects,
    )
@@ -1 +1,27 @@
1
import sys
from typing import TYPE_CHECKING

from fusion_bench.utils.lazy_imports import LazyImporter

# Maps submodule name -> public names it provides; the submodule is only
# imported when one of these attributes is first accessed.
_import_structure = {
    "linear_warmup": [
        "BaseLinearWarmupScheduler",
        "LinearWarmupScheduler",
        "CosineDecayWithWarmup",
        "PolySchedulerWithWarmup",
    ],
}

if TYPE_CHECKING:
    # Eager imports for static type checkers / IDEs only.
    from .linear_warmup import (
        BaseLinearWarmupScheduler,
        CosineDecayWithWarmup,
        LinearWarmupScheduler,
        PolySchedulerWithWarmup,
    )
else:
    # At runtime, replace this module with a lazy proxy that defers the
    # submodule import declared above until first attribute access.
    sys.modules[__name__] = LazyImporter(
        __name__,
        globals()["__file__"],
        _import_structure,
    )
@@ -0,0 +1,339 @@
1
+ # Copied from https://github.com/KellerJordan/Muon
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+
6
def zeropower_via_newtonschulz5(G, steps: int):
    """Approximately orthogonalize ``G`` with a quintic Newton-Schulz iteration.

    The coefficients are chosen to maximize the slope at zero rather than to
    converge exactly, so the result is not UV^T but roughly US'V^T (where
    USV^T = G is the SVD) with the entries of S' near 1 — empirically just as
    good for Muon's purposes. Runs stably in bfloat16. Supports batched
    inputs: the iteration acts on the last two dimensions.
    """
    assert G.ndim >= 2
    coef_a, coef_b, coef_c = (3.4445, -4.7750, 2.0315)
    needs_transpose = G.size(-2) > G.size(-1)

    X = G.bfloat16()
    if needs_transpose:
        # Work in the wide orientation; transpose back at the end.
        X = X.mT

    # Ensure the spectral norm is at most 1 (required for convergence).
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    # Quintic Newton-Schulz iterations.
    for _ in range(steps):
        A = X @ X.mT
        B = coef_b * A + coef_c * A @ A
        X = coef_a * X + B @ X

    if needs_transpose:
        X = X.mT
    return X
37
+
38
+
39
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    """Compute one Muon update direction from a gradient and momentum buffer.

    NOTE: ``momentum`` (and, when ``nesterov`` is True, ``grad``) is modified
    in place, matching the reference implementation.
    """
    # Exponential moving average of gradients.
    momentum.lerp_(grad, 1 - beta)
    if nesterov:
        update = grad.lerp_(momentum, beta)
    else:
        update = momentum
    # Conv filters: flatten each output channel to a row so the input is 2D.
    if update.ndim == 4:
        update = update.view(len(update), -1)
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
    # Scale tall matrices up so the update magnitude is shape independent.
    update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
    return update
47
+
48
+
49
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. For efficient orthogonalization we use a Newton-Schulz iteration, which has the
    advantage that it can be stably run in bfloat16 on the GPU.

    Muon should only be used for hidden weight layers. The input embedding, final output layer,
    and any internal gains or biases should be optimized using a standard method such as AdamW.
    Hidden convolutional weights can be trained using Muon by viewing them as 2D and then
    collapsing their last 3 dimensions.

    Arguments:
        lr: The learning rate, in units of spectral norm per update.
        weight_decay: The AdamW-style weight decay.
        momentum: The momentum. A value of 0.95 here is usually fine.
    """

    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        # A flat list of Parameters is required (not param-group dicts).
        assert (
            isinstance(params, list)
            and len(params) >= 1
            and isinstance(params[0], torch.nn.Parameter)
        )
        # Sort by size so the round-robin sharding in step() gives each rank
        # similarly-sized tensors.
        params = sorted(params, key=lambda x: x.size(), reverse=True)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        # One distributed step: parameters are sharded round-robin across
        # ranks; each rank updates every world_size-th parameter and results
        # are exchanged with all_gather.
        # NOTE(review): assumes torch.distributed is initialized — confirm.

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params = group["params"]
            # Pad the list so its length is a multiple of world_size; the
            # padding tensors only serve as all_gather targets.
            params_pad = params + [torch.empty_like(params[-1])] * (
                dist.get_world_size() - len(params) % dist.get_world_size()
            )
            for base_i in range(len(params))[:: dist.get_world_size()]:
                if base_i + dist.get_rank() < len(params):
                    p = params[base_i + dist.get_rank()]
                    if p.grad is None:
                        # continue
                        # A zero grad keeps this rank in lockstep with the
                        # others instead of skipping the collective.
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(
                        p.grad, state["momentum_buffer"], beta=group["momentum"]
                    )
                    # Decoupled (AdamW-style) weight decay, then Muon update.
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
                # Every rank participates in the collective, including the
                # padded tail, to exchange the freshly updated parameters.
                dist.all_gather(
                    params_pad[base_i : base_i + dist.get_world_size()],
                    params_pad[base_i + dist.get_rank()],
                )

        return loss
114
+
115
+
116
class SingleDeviceMuon(torch.optim.Optimizer):
    """
    Muon variant for usage in non-distributed settings.
    """

    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        super().__init__(
            params, dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        )

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step; returns the closure's loss, if any."""
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]
            for param in group["params"]:
                if param.grad is None:
                    # Mirror the distributed variant: a missing grad is
                    # treated as zeros rather than skipped.
                    param.grad = torch.zeros_like(param)  # Force synchronization
                state = self.state[param]
                if not state:
                    state["momentum_buffer"] = torch.zeros_like(param)
                direction = muon_update(
                    param.grad, state["momentum_buffer"], beta=group["momentum"]
                )
                # Decoupled (AdamW-style) weight decay, then the Muon step.
                param.mul_(1 - lr * wd)
                param.add_(direction.reshape(param.shape), alpha=-lr)

        return loss
148
+
149
+
150
def adam_update(grad, buf1, buf2, step, betas, eps):
    """Return the bias-corrected Adam step direction for ``grad``.

    ``buf1``/``buf2`` are the first/second moment buffers and are updated in
    place; ``step`` is the 1-based step count used for bias correction.
    """
    beta1, beta2 = betas
    buf1.lerp_(grad, 1 - beta1)  # first moment: EMA of gradients
    buf2.lerp_(grad.square(), 1 - beta2)  # second moment: EMA of squared grads
    mean_hat = buf1 / (1 - beta1**step)
    var_hat = buf2 / (1 - beta2**step)
    return mean_hat / (var_hat.sqrt() + eps)
156
+
157
+
158
class MuonWithAuxAdam(torch.optim.Optimizer):
    """
    Distributed Muon variant that can be used for all parameters in the network, since it runs an
    internal AdamW for the parameters that are not compatible with Muon. The user must manually
    specify which parameters shall be optimized with Muon and which with Adam by passing in a
    list of param_groups with the `use_muon` flag set.

    The point of this class is to allow the user to have a single optimizer in their code, rather
    than having both a Muon and an Adam which each need to be stepped.

    You can see an example usage below:

    https://github.com/KellerJordan/modded-nanogpt/blob/master/records/052525_MuonWithAuxAdamExample/b01550f9-03d8-4a9c-86fe-4ab434f1c5e0.txt#L470
    ```
    hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
    embed_params = [p for n, p in model.named_parameters() if "embed" in n]
    scalar_params = [p for p in model.parameters() if p.ndim < 2]
    head_params = [model.lm_head.weight]

    from muon import MuonWithAuxAdam
    adam_groups = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
    adam_groups = [dict(**g, betas=(0.8, 0.95), eps=1e-10, use_muon=False) for g in adam_groups]
    muon_group = dict(params=hidden_matrix_params, lr=0.05, momentum=0.95, use_muon=True)
    param_groups = [*adam_groups, muon_group]
    optimizer = MuonWithAuxAdam(param_groups)
    ```
    """

    def __init__(self, param_groups):
        # Each group must declare whether it uses Muon or the auxiliary Adam;
        # defaults are filled in and the key set is validated strictly.
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                # Sort by size to balance the round-robin sharding in step().
                group["params"] = sorted(
                    group["params"], key=lambda x: x.size(), reverse=True
                )
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(
                    ["params", "lr", "momentum", "weight_decay", "use_muon"]
                )
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(
                    ["params", "lr", "betas", "eps", "weight_decay", "use_muon"]
                )
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):
        # One step: Muon groups are updated with the distributed sharded
        # scheme (see Muon.step); Adam groups are updated locally per rank.
        # NOTE(review): assumes torch.distributed is initialized — confirm.

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                params = group["params"]
                # Pad so the list length is a multiple of world_size; padding
                # tensors only serve as all_gather targets.
                params_pad = params + [torch.empty_like(params[-1])] * (
                    dist.get_world_size() - len(params) % dist.get_world_size()
                )
                for base_i in range(len(params))[:: dist.get_world_size()]:
                    if base_i + dist.get_rank() < len(params):
                        p = params[base_i + dist.get_rank()]
                        if p.grad is None:
                            # continue
                            # Zero grad keeps this rank in lockstep instead of
                            # skipping the collective below.
                            p.grad = torch.zeros_like(p)  # Force synchronization
                        state = self.state[p]
                        if len(state) == 0:
                            state["momentum_buffer"] = torch.zeros_like(p)
                        update = muon_update(
                            p.grad, state["momentum_buffer"], beta=group["momentum"]
                        )
                        # Decoupled weight decay, then the Muon update.
                        p.mul_(1 - group["lr"] * group["weight_decay"])
                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
                    # All ranks take part in the collective, padded tail included.
                    dist.all_gather(
                        params_pad[base_i : base_i + dist.get_world_size()],
                        params_pad[base_i + dist.get_rank()],
                    )
            else:
                # Auxiliary AdamW path for non-Muon-compatible parameters.
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(
                        p.grad,
                        state["exp_avg"],
                        state["exp_avg_sq"],
                        state["step"],
                        group["betas"],
                        group["eps"],
                    )
                    # Decoupled weight decay, then the Adam update.
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss
266
+
267
+
268
class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
    """
    Non-distributed variant of MuonWithAuxAdam.
    """

    def __init__(self, param_groups):
        # Each group must declare whether it uses Muon or the auxiliary Adam;
        # defaults are filled in and the key set is validated strictly.
        for group in param_groups:
            assert "use_muon" in group
            if group["use_muon"]:
                # defaults
                group["lr"] = group.get("lr", 0.02)
                group["momentum"] = group.get("momentum", 0.95)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(
                    ["params", "lr", "momentum", "weight_decay", "use_muon"]
                )
            else:
                # defaults
                group["lr"] = group.get("lr", 3e-4)
                group["betas"] = group.get("betas", (0.9, 0.95))
                group["eps"] = group.get("eps", 1e-10)
                group["weight_decay"] = group.get("weight_decay", 0)
                assert set(group.keys()) == set(
                    ["params", "lr", "betas", "eps", "weight_decay", "use_muon"]
                )
        super().__init__(param_groups, dict())

    @torch.no_grad()
    def step(self, closure=None):
        # One step: Muon groups use muon_update, Adam groups use adam_update;
        # both apply decoupled (AdamW-style) weight decay first.

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            if group["use_muon"]:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        # Missing grads are treated as zeros, matching the
                        # distributed variant's behavior.
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["momentum_buffer"] = torch.zeros_like(p)
                    update = muon_update(
                        p.grad, state["momentum_buffer"], beta=group["momentum"]
                    )
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
            else:
                for p in group["params"]:
                    if p.grad is None:
                        # continue
                        p.grad = torch.zeros_like(p)  # Force synchronization
                    state = self.state[p]
                    if len(state) == 0:
                        state["exp_avg"] = torch.zeros_like(p)
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
                    update = adam_update(
                        p.grad,
                        state["exp_avg"],
                        state["exp_avg_sq"],
                        state["step"],
                        group["betas"],
                        group["eps"],
                    )
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss
@@ -7,11 +7,13 @@ from fusion_bench.utils.lazy_imports import LazyImporter
7
7
  _import_structure = {
8
8
  "base_program": ["BaseHydraProgram"],
9
9
  "fabric_fusion_program": ["FabricModelFusionProgram"],
10
+ "fusion_program": ["ModelFusionProgram"],
10
11
  }
11
12
 
12
13
  if TYPE_CHECKING:
13
14
  from .base_program import BaseHydraProgram
14
15
  from .fabric_fusion_program import FabricModelFusionProgram
16
+ from .fusion_program import ModelFusionProgram
15
17
  else:
16
18
  sys.modules[__name__] = LazyImporter(
17
19
  __name__,