magma-optimizer 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- magma_optimizer-0.1.0/.github/workflows/publish.yml +63 -0
- magma_optimizer-0.1.0/.gitignore +17 -0
- magma_optimizer-0.1.0/LICENSE +21 -0
- magma_optimizer-0.1.0/PKG-INFO +17 -0
- magma_optimizer-0.1.0/README.md +78 -0
- magma_optimizer-0.1.0/magma/__init__.py +4 -0
- magma_optimizer-0.1.0/magma/magma.py +184 -0
- magma_optimizer-0.1.0/notebooks/01_quadratic_benchmark.ipynb +401 -0
- magma_optimizer-0.1.0/pyproject.toml +37 -0
- magma_optimizer-0.1.0/tests/__init__.py +0 -0
- magma_optimizer-0.1.0/tests/test_magma.py +152 -0
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
# Release pipeline: test -> build -> publish, triggered only by version tags.
name: Publish to PyPI

on:
  push:
    tags:
      - "v*"  # e.g. v0.1.0; the tag suffix must match magma/__init__.py

# Default token scope is read-only; the publish job widens it for OIDC.
permissions:
  contents: read

jobs:
  # Run the test suite before anything is built or uploaded.
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install -e ".[dev]"
      - name: Run tests
        run: python -m pytest tests/ -v

  # Build sdist + wheel and fail fast if the tag disagrees with __version__.
  build:
    needs: test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install build tools
        run: python -m pip install --upgrade pip build
      - name: Build sdist and wheel
        run: python -m build
      - name: Verify version matches tag
        run: |
          TAG="${GITHUB_REF#refs/tags/v}"
          PKG_VERSION=$(python -c "import re; print(re.search(r'__version__ = \"(.*?)\"', open('magma/__init__.py').read()).group(1))")
          if [ "$TAG" != "$PKG_VERSION" ]; then
            echo "ERROR: Tag v$TAG does not match package version $PKG_VERSION"
            exit 1
          fi
      # Hand the built distributions to the publish job.
      - uses: actions/upload-artifact@v4
        with:
          name: dist
          path: dist/

  # Upload to PyPI via OIDC trusted publishing (no long-lived API token).
  publish:
    needs: build
    runs-on: ubuntu-latest
    environment: pypi  # environment protection rules gate the release
    permissions:
      id-token: write  # required for PyPI trusted publishing
    steps:
      - uses: actions/download-artifact@v4
        with:
          name: dist
          path: dist/
      - name: Publish to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Andrij David
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: magma-optimizer
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Momentum-Aligned Gradient Masking — block-wise stochastic masking wrapper for PyTorch optimizers
|
|
5
|
+
Project-URL: Homepage, https://github.com/andrijdavid/magma-optimizer
|
|
6
|
+
Project-URL: Repository, https://github.com/andrijdavid/magma-optimizer
|
|
7
|
+
License: MIT
|
|
8
|
+
License-File: LICENSE
|
|
9
|
+
Classifier: Development Status :: 3 - Alpha
|
|
10
|
+
Classifier: Intended Audience :: Science/Research
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
14
|
+
Requires-Python: >=3.9
|
|
15
|
+
Requires-Dist: torch>=2.0
|
|
16
|
+
Provides-Extra: dev
|
|
17
|
+
Requires-Dist: pytest>=7.0; extra == 'dev'
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# Magma
|
|
2
|
+
|
|
3
|
+
**Momentum-Aligned Gradient Masking for Adaptive Optimizers**
|
|
4
|
+
|
|
5
|
+
Magma is a lightweight wrapper that applies block-wise stochastic masking to any PyTorch optimizer, modulated by the alignment between gradient momentum and the current gradient. It is an implementation of the algorithm described in *"On Surprising Effectiveness of Masking Updates in Adaptive Optimizers"*[(arXiv 2602.15322)](https://arxiv.org/pdf/2602.15322).
|
|
6
|
+
|
|
7
|
+
The core insight is deceptively simple. At each step, a per-parameter Bernoulli coin flip decides whether to keep or discard the update. Updates that survive are further scaled by a smoothed cosine similarity score between the gradient and its exponential moving average. The base optimizer's internal states — e.g., Adam's running means or RMSProp's squared gradients — are always updated. Only the parameter itself is masked.
|
|
8
|
+
|
|
9
|
+
This acts as a form of implicit regularization, particularly effective under the heterogeneous curvature and heavy-tailed gradient noise characteristic of transformer training.
|
|
10
|
+
|
|
11
|
+
## Installation
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
pip install magma-optimizer
|
|
15
|
+
```
|
|
16
|
+
|
|
17
|
+
Or directly from source:
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
pip install git+https://github.com/andrijdavid/magma-optimizer.git
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
## Usage
|
|
24
|
+
|
|
25
|
+
Magma wraps any instantiated PyTorch optimizer. The interface mirrors what you already know.
|
|
26
|
+
|
|
27
|
+
```python
|
|
28
|
+
from magma import Magma
|
|
29
|
+
import torch
|
|
30
|
+
|
|
31
|
+
model = ... # your model
|
|
32
|
+
base = torch.optim.Adam(model.parameters(), lr=1e-3)
|
|
33
|
+
|
|
34
|
+
optimizer = Magma(
|
|
35
|
+
base,
|
|
36
|
+
mask_prob=0.5, # prob of keeping an update
|
|
37
|
+
tau=2.0, # temperature for the alignment sigmoid
|
|
38
|
+
momentum_beta=0.9, # EMA coefficient for gradient momentum
|
|
39
|
+
alignment_ema=0.9, # EMA coefficient for smoothing the alignment score
|
|
40
|
+
exclude=set(model.embed.parameters()), # skip masking on embeddings
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
for x, y in dataloader:
|
|
44
|
+
optimizer.zero_grad()
|
|
45
|
+
loss = criterion(model(x), y)
|
|
46
|
+
loss.backward()
|
|
47
|
+
optimizer.step()
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
The `exclude` parameter accepts a set of tensors that should bypass masking entirely. The paper recommends excluding embedding layers, as their update dynamics differ from attention and MLP blocks.
|
|
51
|
+
|
|
52
|
+
## Algorithm
|
|
53
|
+
|
|
54
|
+
The procedure, applied at each step for each non-excluded parameter:
|
|
55
|
+
|
|
56
|
+
1. Update momentum EMA: `μ = β·μ + (1−β)·g`
|
|
57
|
+
2. Compute alignment: `s̃ = sigmoid(cosine_similarity(μ, g) / τ)`
|
|
58
|
+
3. Smooth alignment: `s = α·s_prev + (1−α)·s̃`, where `α = alignment_ema` (default 0.9)
|
|
59
|
+
4. Run the base optimizer step (all internal states update normally)
|
|
60
|
+
5. Sample mask: `m ~ Bernoulli(p)`
|
|
61
|
+
6. Apply: `θ = (s·m)·θ_new + (1 − s·m)·θ_old`
|
|
62
|
+
|
|
63
|
+
When the mask is zero, the parameter reverts to its pre-step value. When the mask is one, the update is scaled by the alignment score. The base optimizer sees every gradient regardless.
|
|
64
|
+
|
|
65
|
+
## Citation
|
|
66
|
+
|
|
67
|
+
```bibtex
|
|
68
|
+
@article{joo2026magma,
|
|
69
|
+
title={On Surprising Effectiveness of Masking Updates in Adaptive Optimizers},
|
|
70
|
+
author={Joo, Taejong and Xia, Wenhan and Kim, Cheolmin and Zhang, Ming and Ie, Eugene},
|
|
71
|
+
journal={arXiv preprint arXiv:2602.15322},
|
|
72
|
+
year={2026}
|
|
73
|
+
}
|
|
74
|
+
```
|
|
75
|
+
|
|
76
|
+
## License
|
|
77
|
+
|
|
78
|
+
MIT
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""Magma: Momentum-Aligned Gradient Masking wrapper for PyTorch optimizers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import Tensor
|
|
7
|
+
from torch.optim import Optimizer
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Magma:
    """Momentum-Aligned Gradient Masking wrapper for PyTorch optimizers.

    Wraps any instantiated PyTorch optimizer and applies block-wise
    stochastic masking to its parameter updates, modulated by the
    alignment between the gradient momentum (an EMA of past gradients)
    and the current gradient. The base optimizer's internal state is
    always updated normally; only the parameter values themselves are
    (partially) blended back toward their pre-step values.

    Algorithm described in https://arxiv.org/abs/2602.15322

    Args:
        optimizer: The base optimizer to wrap.
        mask_prob: Probability of *keeping* an update (Bernoulli
            parameter). Must lie in [0, 1].
        tau: Temperature for the alignment sigmoid. Must be positive.
        momentum_beta: EMA coefficient for the gradient momentum.
            Must lie in [0, 1].
        alignment_ema: EMA coefficient for smoothing the alignment
            score. Must lie in [0, 1].
        exclude: Optional set of parameter tensors that bypass masking
            entirely (e.g. embedding weights).

    Raises:
        ValueError: If any hyperparameter is outside its valid range.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        mask_prob: float = 0.5,
        tau: float = 2.0,
        momentum_beta: float = 0.9,
        alignment_ema: float = 0.9,
        exclude: set[Tensor] | None = None,
    ) -> None:
        if not 0.0 <= mask_prob <= 1.0:
            raise ValueError(f"mask_prob must be in [0, 1], got {mask_prob}")
        if tau <= 0.0:
            raise ValueError(f"tau must be positive, got {tau}")
        # Fix: validate the EMA coefficients too — values outside [0, 1]
        # make the momentum/alignment updates non-convex and divergent.
        if not 0.0 <= momentum_beta <= 1.0:
            raise ValueError(
                f"momentum_beta must be in [0, 1], got {momentum_beta}"
            )
        if not 0.0 <= alignment_ema <= 1.0:
            raise ValueError(
                f"alignment_ema must be in [0, 1], got {alignment_ema}"
            )

        self.optimizer = optimizer
        self.mask_prob = mask_prob
        self.tau = tau
        self.momentum_beta = momentum_beta
        self.alignment_ema = alignment_ema
        self._exclude_ids: set[int] = {id(t) for t in (exclude or ())}

        # Magma-specific per-parameter state keyed by id(param):
        #   "momentum":  EMA of gradients (same shape as the parameter)
        #   "alignment": scalar smoothed alignment score, initialised to 1
        # NOTE(review): id() keys can be recycled if parameters are garbage
        # collected and new tensors are allocated; acceptable for the usual
        # create-optimizer-once workflow, but worth confirming for exotic
        # dynamic-graph setups.
        self._state: dict[int, dict[str, Tensor]] = {}

    def step(self, closure=None):
        """Perform a single Magma-wrapped optimisation step.

        Phase 1 updates the per-parameter momentum/alignment statistics
        and snapshots the current parameter values; phase 2 runs the base
        optimizer step and then stochastically blends each parameter
        between its pre- and post-step values.

        Args:
            closure: Optional closure forwarded to the base optimizer.
                NOTE(review): if the closure recomputes gradients, phase 1
                statistics were derived from the *previous* gradients —
                confirm this matches the intended usage.

        Returns:
            Whatever the base optimizer's ``step`` returns (typically the
            closure's loss, or ``None``).
        """
        # --- Phase 1: momentum, alignment, and parameter snapshots -----
        snapshots: dict[int, Tensor] = {}       # id(p) -> pre-step copy
        alignment_scores: dict[int, float] = {}

        for group in self.optimizer.param_groups:
            for p in group["params"]:
                if p.grad is None or id(p) in self._exclude_ids:
                    continue

                pid = id(p)
                grad = p.grad.detach()

                # Lazily initialise per-param Magma state.
                if pid not in self._state:
                    self._state[pid] = {
                        "momentum": torch.zeros_like(p.data),
                        "alignment": torch.tensor(1.0, device=p.device),
                    }
                st = self._state[pid]

                # mu_t = beta * mu_{t-1} + (1 - beta) * g_t
                st["momentum"].mul_(self.momentum_beta).add_(
                    grad, alpha=1.0 - self.momentum_beta
                )

                # cos(mu_t, g_t); cosine_similarity is eps-guarded, so
                # all-zero vectors yield 0 rather than NaN.
                cos = torch.nn.functional.cosine_similarity(
                    st["momentum"].flatten().unsqueeze(0),
                    grad.flatten().unsqueeze(0),
                ).item()

                # s~ = sigmoid(cos / tau)
                s_tilde = torch.sigmoid(
                    torch.tensor(cos / self.tau, device=p.device)
                ).item()

                # Smooth the alignment score with its own EMA.
                s_prev = st["alignment"].item()
                s = self.alignment_ema * s_prev + (1.0 - self.alignment_ema) * s_tilde
                st["alignment"].fill_(s)
                alignment_scores[pid] = s

                # Snapshot pre-step value so we can revert/blend later.
                snapshots[pid] = p.data.clone()

        # --- Phase 2: base step, then stochastic masked blending -------
        loss = self.optimizer.step(closure)

        for group in self.optimizer.param_groups:
            for p in group["params"]:
                pid = id(p)
                if pid not in snapshots:
                    continue

                # m_t ~ Bernoulli(mask_prob). rand(()) in [0, 1) makes the
                # draw deterministic at mask_prob in {0, 1} and avoids the
                # original's throwaway tensor allocation per parameter.
                mask = 1.0 if torch.rand(()).item() < self.mask_prob else 0.0
                # blend = s_t * m_t
                blend = alignment_scores[pid] * mask

                if blend == 0.0:
                    # Fully revert parameter to its pre-step value.
                    p.data.copy_(snapshots[pid])
                elif blend != 1.0:
                    # theta = blend * theta_new + (1 - blend) * theta_old
                    p.data.mul_(blend).add_(snapshots[pid], alpha=1.0 - blend)
                # blend == 1.0: keep the full update as-is.

        return loss

    def zero_grad(self, set_to_none: bool = True):
        """Clear gradients via the base optimizer."""
        self.optimizer.zero_grad(set_to_none=set_to_none)

    @property
    def param_groups(self):
        """The base optimizer's parameter groups (live view)."""
        return self.optimizer.param_groups

    def add_param_group(self, param_group: dict):
        """Forward a new parameter group to the base optimizer."""
        self.optimizer.add_param_group(param_group)

    def state_dict(self):
        """Return a checkpointable state dict.

        Magma state is re-keyed from transient ``id()`` values to stable
        ``(group_index, param_index)`` pairs so it survives serialisation
        and re-association on load.
        """
        id_to_key: dict[int, tuple[int, int]] = {}
        for gi, group in enumerate(self.optimizer.param_groups):
            for pi, p in enumerate(group["params"]):
                id_to_key[id(p)] = (gi, pi)

        magma_state = {}
        for pid, st in self._state.items():
            key = id_to_key.get(pid)
            if key is not None:
                magma_state[key] = {k: v.clone() for k, v in st.items()}

        return {
            "base": self.optimizer.state_dict(),
            "magma_state": magma_state,
            "mask_prob": self.mask_prob,
            "tau": self.tau,
            "momentum_beta": self.momentum_beta,
            "alignment_ema": self.alignment_ema,
        }

    def load_state_dict(self, state_dict: dict):
        """Restore state produced by :meth:`state_dict`.

        Entries whose ``(group, param)`` key no longer maps to a parameter
        are silently dropped.
        """
        self.optimizer.load_state_dict(state_dict["base"])
        self.mask_prob = state_dict["mask_prob"]
        self.tau = state_dict["tau"]
        self.momentum_beta = state_dict["momentum_beta"]
        self.alignment_ema = state_dict["alignment_ema"]

        key_to_id: dict[tuple[int, int], int] = {}
        for gi, group in enumerate(self.optimizer.param_groups):
            for pi, p in enumerate(group["params"]):
                key_to_id[(gi, pi)] = id(p)

        self._state = {}
        for key, st in state_dict["magma_state"].items():
            # Some serialisers round-trip tuple keys as lists; normalise.
            key = tuple(key) if isinstance(key, list) else key
            pid = key_to_id.get(key)
            if pid is not None:
                self._state[pid] = {k: v.clone() for k, v in st.items()}

    def __getattr__(self, name: str):
        # Fallback to the base optimizer for anything not on Magma itself.
        # __getattr__ runs only when normal lookup fails, so guard against
        # infinite recursion when 'optimizer' itself is absent (e.g. during
        # unpickling or copy.copy, before __init__ populated __dict__).
        if name == "optimizer":
            raise AttributeError(name)
        try:
            return getattr(self.optimizer, name)
        except AttributeError:
            # 'from None' keeps the traceback free of the internal lookup.
            raise AttributeError(
                f"Neither 'Magma' nor the base optimizer have attribute '{name}'"
            ) from None

    def __repr__(self) -> str:
        return (
            f"Magma(\n"
            f"  mask_prob={self.mask_prob},\n"
            f"  tau={self.tau},\n"
            f"  momentum_beta={self.momentum_beta},\n"
            f"  alignment_ema={self.alignment_ema},\n"
            f"  base={self.optimizer}\n"
            f")"
        )