magma-optimizer 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,63 @@
1
+ name: Publish to PyPI
2
+
3
+ on:
4
+ push:
5
+ tags:
6
+ - "v*"
7
+
8
+ permissions:
9
+ contents: read
10
+
11
+ jobs:
12
+ test:
13
+ runs-on: ubuntu-latest
14
+ steps:
15
+ - uses: actions/checkout@v4
16
+ - uses: actions/setup-python@v5
17
+ with:
18
+ python-version: "3.11"
19
+ - name: Install dependencies
20
+ run: |
21
+ python -m pip install --upgrade pip
22
+ python -m pip install -e ".[dev]"
23
+ - name: Run tests
24
+ run: python -m pytest tests/ -v
25
+
26
+ build:
27
+ needs: test
28
+ runs-on: ubuntu-latest
29
+ steps:
30
+ - uses: actions/checkout@v4
31
+ - uses: actions/setup-python@v5
32
+ with:
33
+ python-version: "3.11"
34
+ - name: Install build tools
35
+ run: python -m pip install --upgrade pip build
36
+ - name: Build sdist and wheel
37
+ run: python -m build
38
+ - name: Verify version matches tag
39
+ run: |
40
+ TAG="${GITHUB_REF#refs/tags/v}"
41
+ PKG_VERSION=$(python -c "import re; print(re.search(r'__version__ = \"(.*?)\"', open('magma/__init__.py').read()).group(1))")
42
+ if [ "$TAG" != "$PKG_VERSION" ]; then
43
+ echo "ERROR: Tag v$TAG does not match package version $PKG_VERSION"
44
+ exit 1
45
+ fi
46
+ - uses: actions/upload-artifact@v4
47
+ with:
48
+ name: dist
49
+ path: dist/
50
+
51
+ publish:
52
+ needs: build
53
+ runs-on: ubuntu-latest
54
+ environment: pypi
55
+ permissions:
56
+ id-token: write
57
+ steps:
58
+ - uses: actions/download-artifact@v4
59
+ with:
60
+ name: dist
61
+ path: dist/
62
+ - name: Publish to PyPI
63
+ uses: pypa/gh-action-pypi-publish@release/v1
@@ -0,0 +1,17 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.egg-info/
5
+ dist/
6
+ build/
7
+ *.egg
8
+ .eggs/
9
+ *.so
10
+ .venv/
11
+ venv/
12
+ env/
13
+ .pytest_cache/
14
+ .mypy_cache/
15
+ .ruff_cache/
16
+ .vscode
17
+ data/
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Andrij David
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,17 @@
1
+ Metadata-Version: 2.4
2
+ Name: magma-optimizer
3
+ Version: 0.1.0
4
+ Summary: Momentum-Aligned Gradient Masking — block-wise stochastic masking wrapper for PyTorch optimizers
5
+ Project-URL: Homepage, https://github.com/andrijdavid/magma-optimizer
6
+ Project-URL: Repository, https://github.com/andrijdavid/magma-optimizer
7
+ License: MIT
8
+ License-File: LICENSE
9
+ Classifier: Development Status :: 3 - Alpha
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Requires-Python: >=3.9
15
+ Requires-Dist: torch>=2.0
16
+ Provides-Extra: dev
17
+ Requires-Dist: pytest>=7.0; extra == 'dev'
@@ -0,0 +1,78 @@
1
+ # Magma
2
+
3
+ **Momentum-Aligned Gradient Masking for Adaptive Optimizers**
4
+
5
+ Magma is a lightweight wrapper that applies block-wise stochastic masking to any PyTorch optimizer, modulated by the alignment between gradient momentum and the current gradient. It is an implementation of the algorithm described in *"On Surprising Effectiveness of Masking Updates in Adaptive Optimizers"* [(arXiv 2602.15322)](https://arxiv.org/pdf/2602.15322).
6
+
7
+ The core insight is deceptively simple. At each step, a per-parameter Bernoulli coin flip decides whether to keep or discard the update. Updates that survive are further scaled by a smoothed cosine similarity score between the gradient and its exponential moving average. The base optimizer's internal states — e.g. Adam's running means or RMSProp's squared gradients — are always updated. Only the parameter itself is masked.
8
+
9
+ This acts as a form of implicit regularization, particularly effective under the heterogeneous curvature and heavy-tailed gradient noise characteristic of transformer training.
10
+
11
+ ## Installation
12
+
13
+ ```bash
14
+ pip install magma-optimizer
15
+ ```
16
+
17
+ Or directly from source:
18
+
19
+ ```bash
20
+ pip install git+https://github.com/andrijdavid/magma-optimizer.git
21
+ ```
22
+
23
+ ## Usage
24
+
25
+ Magma wraps any instantiated PyTorch optimizer. The interface mirrors what you already know.
26
+
27
+ ```python
28
+ from magma import Magma
29
+ import torch
30
+
31
+ model = ... # your model
32
+ base = torch.optim.Adam(model.parameters(), lr=1e-3)
33
+
34
+ optimizer = Magma(
35
+ base,
36
+ mask_prob=0.5, # prob of keeping an update
37
+ tau=2.0, # temperature for the alignment sigmoid
38
+ momentum_beta=0.9, # EMA coefficient for gradient momentum
39
+ alignment_ema=0.9, # EMA coefficient for smoothing the alignment score
40
+ exclude=set(model.embed.parameters()), # skip masking on embeddings
41
+ )
42
+
43
+ for x, y in dataloader:
44
+ optimizer.zero_grad()
45
+ loss = criterion(model(x), y)
46
+ loss.backward()
47
+ optimizer.step()
48
+ ```
49
+
50
+ The `exclude` parameter accepts a set of tensors that should bypass masking entirely. The paper recommends excluding embedding layers, as their update dynamics differ from attention and MLP blocks.
51
+
52
+ ## Algorithm
53
+
54
+ The procedure, applied at each step for each non-excluded parameter:
55
+
56
+ 1. Update momentum EMA: `μ = β·μ + (1−β)·g`
57
+ 2. Compute alignment: `s̃ = sigmoid(cosine_similarity(μ, g) / τ)`
58
+ 3. Smooth alignment: `s = 0.9·s_prev + 0.1·s̃`
59
+ 4. Run the base optimizer step (all internal states update normally)
60
+ 5. Sample mask: `m ~ Bernoulli(p)`
61
+ 6. Apply: `θ = (s·m)·θ_new + (1 − s·m)·θ_old`
62
+
63
+ When the mask is zero, the parameter reverts to its pre-step value. When the mask is one, the update is scaled by the alignment score. The base optimizer sees every gradient regardless.
64
+
65
+ ## Citation
66
+
67
+ ```bibtex
68
+ @article{joo2026magma,
69
+ title={On Surprising Effectiveness of Masking Updates in Adaptive Optimizers},
70
+ author={Joo, Taejong and Xia, Wenhan and Kim, Cheolmin and Zhang, Ming and Ie, Eugene},
71
+ journal={arXiv preprint arXiv:2602.15322},
72
+ year={2026}
73
+ }
74
+ ```
75
+
76
+ ## License
77
+
78
+ MIT
@@ -0,0 +1,4 @@
1
# Public package surface: re-export the Magma wrapper class.
from magma.magma import Magma

# Keep in sync with the release tag; the publish workflow fails the build
# if the pushed tag "v<X>" does not match this __version__ string.
__version__ = "0.1.0"
__all__ = ["Magma"]
@@ -0,0 +1,184 @@
1
+ """Magma: Momentum-Aligned Gradient Masking wrapper for PyTorch optimizers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ from torch.optim import Optimizer
8
+
9
+
10
class Magma:
    """Momentum-Aligned Gradient Masking (Magma) wrapper for PyTorch optimizers.

    Wraps any instantiated PyTorch optimizer and applies block-wise stochastic
    masking to parameter updates, modulated by the alignment between the
    gradient momentum (an EMA of past gradients) and the current gradient.
    Algorithm described at https://arxiv.org/abs/2602.15322

    The base optimizer's internal state is always updated on every step;
    only the parameter values themselves may be (partially) reverted.

    Args:
        optimizer: The base optimizer to wrap (e.g. Adam, SGD).
        mask_prob: Probability of *keeping* an update; must be in [0, 1].
        tau: Temperature for the alignment sigmoid; must be positive.
        momentum_beta: EMA coefficient for the gradient momentum; in [0, 1).
        alignment_ema: EMA coefficient for smoothing the alignment score; in [0, 1).
        exclude: Set of parameter tensors that bypass masking entirely
            (e.g. embedding weights, per the paper's recommendation).

    Raises:
        ValueError: If any coefficient is outside its valid range.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        mask_prob: float = 0.5,
        tau: float = 2.0,
        momentum_beta: float = 0.9,
        alignment_ema: float = 0.9,
        exclude: set[Tensor] | None = None,
    ) -> None:
        if not 0.0 <= mask_prob <= 1.0:
            raise ValueError(f"mask_prob must be in [0, 1], got {mask_prob}")
        if tau <= 0.0:
            raise ValueError(f"tau must be positive, got {tau}")
        # Validate EMA coefficients for consistency with mask_prob/tau checks;
        # a value of exactly 1.0 would freeze the EMA forever, so it is rejected.
        if not 0.0 <= momentum_beta < 1.0:
            raise ValueError(f"momentum_beta must be in [0, 1), got {momentum_beta}")
        if not 0.0 <= alignment_ema < 1.0:
            raise ValueError(f"alignment_ema must be in [0, 1), got {alignment_ema}")

        self.optimizer = optimizer
        self.mask_prob = mask_prob
        self.tau = tau
        self.momentum_beta = momentum_beta
        self.alignment_ema = alignment_ema
        # Excluded params are tracked by identity so tensor equality never runs.
        self._exclude_ids: set[int] = {id(t) for t in (exclude or ())}

        # Magma-specific per-parameter state keyed by param id:
        # "momentum" (EMA of gradients) and "alignment" (smoothed score).
        self._state: dict[int, dict[str, Tensor]] = {}

    def step(self, closure=None):
        """Perform a single Magma-wrapped optimisation step.

        Phase 1 updates the Magma momentum/alignment state and snapshots each
        eligible parameter; phase 2 runs the base optimizer, then blends each
        parameter between its pre-step and post-step values according to the
        sampled mask and the alignment score.

        Returns:
            Whatever the base optimizer's ``step`` returns (the closure loss,
            or ``None``).
        """

        # --- Phase 1: pre-step (momentum, alignment, snapshot) --------
        snapshots: dict[int, Tensor] = {}  # id(p) -> p.data clone
        alignment_scores: dict[int, float] = {}

        for group in self.optimizer.param_groups:
            for p in group["params"]:
                # Skip params without gradients and explicitly excluded params;
                # these receive the base optimizer's update unmodified.
                if p.grad is None or id(p) in self._exclude_ids:
                    continue

                pid = id(p)
                grad = p.grad.detach()

                # Lazily initialise per-param Magma state. Alignment starts at
                # 1.0 so early updates are not spuriously damped.
                if pid not in self._state:
                    self._state[pid] = {
                        "momentum": torch.zeros_like(p.data),
                        "alignment": torch.tensor(1.0, device=p.device),
                    }

                st = self._state[pid]

                # μ_t = β μ_{t-1} + (1-β) g_t
                st["momentum"].mul_(self.momentum_beta).add_(
                    grad, alpha=1.0 - self.momentum_beta
                )

                # cossim(μ_t, g_t) over the flattened tensors
                cos = torch.nn.functional.cosine_similarity(
                    st["momentum"].flatten().unsqueeze(0),
                    grad.flatten().unsqueeze(0),
                ).item()

                # s̃ = sigmoid(cos/τ)
                s_tilde = torch.sigmoid(
                    torch.tensor(cos / self.tau, device=p.device)
                ).item()

                # EMA smoothing of the alignment score
                s_prev = st["alignment"].item()
                s = self.alignment_ema * s_prev + (1.0 - self.alignment_ema) * s_tilde
                st["alignment"].fill_(s)

                alignment_scores[pid] = s

                # Save current params so the update can be (partially) reverted.
                snapshots[pid] = p.data.clone()

        # --- Phase 2: base step, then mask/blend ----------------------
        loss = self.optimizer.step(closure)

        for group in self.optimizer.param_groups:
            for p in group["params"]:
                pid = id(p)
                if pid not in snapshots:
                    continue

                # m_t ~ Bernoulli(mask_prob): one coin flip per parameter tensor
                # (block-wise masking, not element-wise).
                mask = float(torch.bernoulli(torch.tensor(self.mask_prob)).item())
                s = alignment_scores[pid]
                # blend = s_t * m_t
                blend = s * mask

                if blend == 0.0:
                    # Fully revert parameter to pre-step value
                    p.data.copy_(snapshots[pid])
                elif blend != 1.0:
                    # θ = blend*θ_new + (1-blend)*θ_old
                    p.data.mul_(blend).add_(snapshots[pid], alpha=1.0 - blend)
                # blend == 1.0: keep the base optimizer's update as-is

        return loss

    def zero_grad(self, set_to_none: bool = True):
        """Delegate gradient clearing to the base optimizer."""
        self.optimizer.zero_grad(set_to_none=set_to_none)

    @property
    def param_groups(self):
        """The base optimizer's parameter groups (read-through)."""
        return self.optimizer.param_groups

    def add_param_group(self, param_group: dict):
        """Register an additional parameter group with the base optimizer."""
        self.optimizer.add_param_group(param_group)

    def state_dict(self):
        """Return a serialisable state dict.

        Magma state is re-keyed from volatile ``id(p)`` to stable
        ``(group_index, param_index)`` tuples so it survives a save/load
        round-trip across processes.
        """
        id_to_key: dict[int, tuple[int, int]] = {}
        for gi, group in enumerate(self.optimizer.param_groups):
            for pi, p in enumerate(group["params"]):
                id_to_key[id(p)] = (gi, pi)

        magma_state = {}
        for pid, st in self._state.items():
            key = id_to_key.get(pid)
            if key is not None:
                magma_state[key] = {k: v.clone() for k, v in st.items()}

        return {
            "base": self.optimizer.state_dict(),
            "magma_state": magma_state,
            "mask_prob": self.mask_prob,
            "tau": self.tau,
            "momentum_beta": self.momentum_beta,
            "alignment_ema": self.alignment_ema,
        }

    def load_state_dict(self, state_dict: dict):
        """Restore state previously produced by :meth:`state_dict`."""
        self.optimizer.load_state_dict(state_dict["base"])
        self.mask_prob = state_dict["mask_prob"]
        self.tau = state_dict["tau"]
        self.momentum_beta = state_dict["momentum_beta"]
        self.alignment_ema = state_dict["alignment_ema"]

        key_to_id: dict[tuple[int, int], int] = {}
        for gi, group in enumerate(self.optimizer.param_groups):
            for pi, p in enumerate(group["params"]):
                key_to_id[(gi, pi)] = id(p)

        self._state = {}
        for key, st in state_dict["magma_state"].items():
            # Some serialisers round-trip tuple keys as lists; normalise.
            key = tuple(key) if isinstance(key, list) else key
            pid = key_to_id.get(key)
            if pid is not None:
                self._state[pid] = {k: v.clone() for k, v in st.items()}

    def __getattr__(self, name: str):
        # Fallback to base optimizer for anything not on Magma itself.
        # Guard against infinite recursion: if 'optimizer' itself is absent
        # from __dict__ (e.g. during unpickling or before __init__ runs),
        # looking it up via getattr would re-enter __getattr__ forever.
        if name == "optimizer":
            raise AttributeError(
                f"Neither 'Magma' nor the base optimizer have attribute '{name}'"
            )
        try:
            return getattr(self.optimizer, name)
        except AttributeError:
            # Suppress the chained traceback; the original AttributeError
            # adds no information beyond this message.
            raise AttributeError(
                f"Neither 'Magma' nor the base optimizer have attribute '{name}'"
            ) from None

    def __repr__(self) -> str:
        return (
            f"Magma(\n"
            f"  mask_prob={self.mask_prob},\n"
            f"  tau={self.tau},\n"
            f"  momentum_beta={self.momentum_beta},\n"
            f"  alignment_ema={self.alignment_ema},\n"
            f"  base={self.optimizer}\n"
            f")"
        )