stac-optimizer 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stac_optimizer-0.1.2/.github/workflows/workflow.yml +137 -0
- stac_optimizer-0.1.2/.gitignore +7 -0
- stac_optimizer-0.1.2/CHANGELOG.md +30 -0
- stac_optimizer-0.1.2/PKG-INFO +232 -0
- stac_optimizer-0.1.2/README.md +209 -0
- stac_optimizer-0.1.2/examples/toy_benchmark.py +88 -0
- stac_optimizer-0.1.2/pyproject.toml +47 -0
- stac_optimizer-0.1.2/setup.cfg +4 -0
- stac_optimizer-0.1.2/src/stac_optimizer/__init__.py +8 -0
- stac_optimizer-0.1.2/src/stac_optimizer/stac.py +461 -0
- stac_optimizer-0.1.2/src/stac_optimizer.egg-info/PKG-INFO +232 -0
- stac_optimizer-0.1.2/src/stac_optimizer.egg-info/SOURCES.txt +14 -0
- stac_optimizer-0.1.2/src/stac_optimizer.egg-info/dependency_links.txt +1 -0
- stac_optimizer-0.1.2/src/stac_optimizer.egg-info/requires.txt +7 -0
- stac_optimizer-0.1.2/src/stac_optimizer.egg-info/top_level.txt +1 -0
- stac_optimizer-0.1.2/tests/test_stac.py +471 -0
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
name: CI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
pull_request:
|
|
5
|
+
push:
|
|
6
|
+
branches:
|
|
7
|
+
- main
|
|
8
|
+
tags:
|
|
9
|
+
- "v*"
|
|
10
|
+
|
|
11
|
+
permissions:
|
|
12
|
+
contents: read
|
|
13
|
+
|
|
14
|
+
jobs:
|
|
15
|
+
test:
|
|
16
|
+
runs-on: ubuntu-latest
|
|
17
|
+
|
|
18
|
+
steps:
|
|
19
|
+
- name: Check out repository
|
|
20
|
+
uses: actions/checkout@v4
|
|
21
|
+
with:
|
|
22
|
+
fetch-depth: 0
|
|
23
|
+
|
|
24
|
+
- name: Set up Python
|
|
25
|
+
uses: actions/setup-python@v5
|
|
26
|
+
with:
|
|
27
|
+
python-version: "3.13"
|
|
28
|
+
|
|
29
|
+
- name: Install dependencies
|
|
30
|
+
run: |
|
|
31
|
+
python -m pip install --upgrade pip
|
|
32
|
+
python -m pip install --extra-index-url https://download.pytorch.org/whl/cpu -e ".[dev]"
|
|
33
|
+
|
|
34
|
+
- name: Run tests
|
|
35
|
+
run: python -m pytest
|
|
36
|
+
|
|
37
|
+
package:
|
|
38
|
+
needs: test
|
|
39
|
+
runs-on: ubuntu-latest
|
|
40
|
+
|
|
41
|
+
steps:
|
|
42
|
+
- name: Check out repository
|
|
43
|
+
uses: actions/checkout@v4
|
|
44
|
+
with:
|
|
45
|
+
fetch-depth: 0
|
|
46
|
+
|
|
47
|
+
- name: Set up Python
|
|
48
|
+
uses: actions/setup-python@v5
|
|
49
|
+
with:
|
|
50
|
+
python-version: "3.13"
|
|
51
|
+
|
|
52
|
+
- name: Install packaging dependencies
|
|
53
|
+
run: |
|
|
54
|
+
python -m pip install --upgrade pip
|
|
55
|
+
python -m pip install build twine
|
|
56
|
+
python -m pip install --extra-index-url https://download.pytorch.org/whl/cpu "torch>=2.10"
|
|
57
|
+
|
|
58
|
+
- name: Build distributions
|
|
59
|
+
run: python -m build
|
|
60
|
+
|
|
61
|
+
- name: Validate package metadata
|
|
62
|
+
run: python -m twine check dist/*
|
|
63
|
+
|
|
64
|
+
- name: Smoke test built wheel
|
|
65
|
+
run: |
|
|
66
|
+
python -m pip install --force-reinstall --no-deps dist/*.whl
|
|
67
|
+
python - <<'PY'
|
|
68
|
+
from torch import nn
|
|
69
|
+
from stac_optimizer import STAC
|
|
70
|
+
|
|
71
|
+
model = nn.Sequential(
|
|
72
|
+
nn.Linear(4, 4),
|
|
73
|
+
nn.ReLU(),
|
|
74
|
+
nn.Linear(4, 2),
|
|
75
|
+
)
|
|
76
|
+
optimizer = STAC(model, last_n_layers=1)
|
|
77
|
+
assert optimizer.partition.cap_layer_names == ("2",)
|
|
78
|
+
PY
|
|
79
|
+
|
|
80
|
+
release:
|
|
81
|
+
if: startsWith(github.ref, 'refs/tags/v')
|
|
82
|
+
needs: [test, package]
|
|
83
|
+
runs-on: ubuntu-latest
|
|
84
|
+
permissions:
|
|
85
|
+
contents: write
|
|
86
|
+
id-token: write
|
|
87
|
+
|
|
88
|
+
steps:
|
|
89
|
+
- name: Check out repository
|
|
90
|
+
uses: actions/checkout@v4
|
|
91
|
+
with:
|
|
92
|
+
fetch-depth: 0
|
|
93
|
+
|
|
94
|
+
- name: Set up Python
|
|
95
|
+
uses: actions/setup-python@v5
|
|
96
|
+
with:
|
|
97
|
+
python-version: "3.13"
|
|
98
|
+
|
|
99
|
+
- name: Install build tooling
|
|
100
|
+
run: |
|
|
101
|
+
python -m pip install --upgrade pip
|
|
102
|
+
python -m pip install build setuptools-scm twine
|
|
103
|
+
|
|
104
|
+
- name: Verify tag matches package version
|
|
105
|
+
env:
|
|
106
|
+
TAG_NAME: ${{ github.ref_name }}
|
|
107
|
+
run: |
|
|
108
|
+
python - <<'PY'
|
|
109
|
+
import os
|
|
110
|
+
from setuptools_scm import get_version
|
|
111
|
+
|
|
112
|
+
tag_name = os.environ["TAG_NAME"]
|
|
113
|
+
tag_version = tag_name.removeprefix("v")
|
|
114
|
+
package_version = get_version(root=".", version_scheme="guess-next-dev", local_scheme="no-local-version")
|
|
115
|
+
|
|
116
|
+
if tag_version != package_version:
|
|
117
|
+
raise SystemExit(
|
|
118
|
+
f"Tag version {tag_version!r} does not match package version {package_version!r}."
|
|
119
|
+
)
|
|
120
|
+
PY
|
|
121
|
+
|
|
122
|
+
- name: Build distributions
|
|
123
|
+
run: python -m build
|
|
124
|
+
|
|
125
|
+
- name: Validate package metadata
|
|
126
|
+
run: python -m twine check dist/*
|
|
127
|
+
|
|
128
|
+
- name: Publish distributions to PyPI
|
|
129
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
130
|
+
with:
|
|
131
|
+
packages-dir: dist/
|
|
132
|
+
|
|
133
|
+
- name: Create GitHub Release
|
|
134
|
+
uses: softprops/action-gh-release@v2
|
|
135
|
+
with:
|
|
136
|
+
generate_release_notes: true
|
|
137
|
+
files: dist/*
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
All notable changes to this project will be documented in this file.
|
|
4
|
+
|
|
5
|
+
## 0.1.2 - 2026-03-18
|
|
6
|
+
|
|
7
|
+
- Added partition-aware `load_state_dict()` validation so STAC checkpoints fail
|
|
8
|
+
fast when loaded into a mismatched trunk/cap split.
|
|
9
|
+
- Fixed the AdamW cap so `amsgrad=True` now matches PyTorch's AMSGrad behavior.
|
|
10
|
+
- Moved optimizer step regression tests and the effectiveness benchmark to CUDA
|
|
11
|
+
coverage on supported machines.
|
|
12
|
+
- Rewrote the README as Markdown-only documentation and removed tracked SVG
|
|
13
|
+
assets from the repository.
|
|
14
|
+
- Added a GitHub Actions path for PyPI Trusted Publishing on release tags.
|
|
15
|
+
|
|
16
|
+
## 0.1.1 - 2026-03-18
|
|
17
|
+
|
|
18
|
+
- Added a momentum-accumulating sign trunk that keeps the STAC update sign-based
|
|
19
|
+
while improving convergence stability.
|
|
20
|
+
- Added role-specific hyperparameters for trunk vs. cap learning rates and
|
|
21
|
+
weight decay.
|
|
22
|
+
- Added opt-in non-finite gradient checks and broader optimizer regression
|
|
23
|
+
tests.
|
|
24
|
+
- Tightened package metadata to Python 3.13 and added packaging validation to
|
|
25
|
+
CI.
|
|
26
|
+
- Refreshed the README with diagrams, benchmark notes, and release guidance.
|
|
27
|
+
|
|
28
|
+
## 0.1.0 - 2026-03-18
|
|
29
|
+
|
|
30
|
+
- Initial public release of the STAC optimizer.
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: stac-optimizer
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: STAC optimizer with a signSGD trunk and an AdamW cap for the last N trainable layers.
|
|
5
|
+
Project-URL: Homepage, https://github.com/smturtle2/stac-optimizer
|
|
6
|
+
Project-URL: Repository, https://github.com/smturtle2/stac-optimizer
|
|
7
|
+
Project-URL: Issues, https://github.com/smturtle2/stac-optimizer/issues
|
|
8
|
+
Project-URL: Changelog, https://github.com/smturtle2/stac-optimizer/blob/main/CHANGELOG.md
|
|
9
|
+
Keywords: pytorch,optimizer,signsgd,adamw
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: Programming Language :: Python :: 3
|
|
12
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
14
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
15
|
+
Requires-Python: >=3.13
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
Requires-Dist: torch>=2.10
|
|
18
|
+
Provides-Extra: dev
|
|
19
|
+
Requires-Dist: build>=1.2; extra == "dev"
|
|
20
|
+
Requires-Dist: pytest>=8.3; extra == "dev"
|
|
21
|
+
Requires-Dist: setuptools-scm>=8; extra == "dev"
|
|
22
|
+
Requires-Dist: twine>=6.1; extra == "dev"
|
|
23
|
+
|
|
24
|
+
# stac-optimizer
|
|
25
|
+
|
|
26
|
+
STAC stands for SignSGD Trunk, AdamW Cap.
|
|
27
|
+
|
|
28
|
+
It is a PyTorch optimizer for models where you want cheap sign-based updates
|
|
29
|
+
through most of the network, but still want AdamW on the last few trainable
|
|
30
|
+
layers where optimization is often most sensitive. The default trunk is
|
|
31
|
+
`sign(momentum)` rather than plain `sign(grad)` because the momentum-smoothed
|
|
32
|
+
variant is materially more stable in both theory and practice.
|
|
33
|
+
|
|
34
|
+
| Item | Value |
|
|
35
|
+
| --- | --- |
|
|
36
|
+
| Python | `>=3.13` |
|
|
37
|
+
| PyTorch | `>=2.10` |
|
|
38
|
+
| Default split | last `1` trainable layer uses AdamW |
|
|
39
|
+
| Trunk update | sign-based update with momentum smoothing |
|
|
40
|
+
| Cap update | AdamW with decoupled weight decay |
|
|
41
|
+
|
|
42
|
+
## Why STAC
|
|
43
|
+
|
|
44
|
+
- Keeps the bulk of the model on sign-based updates.
|
|
45
|
+
- Preserves AdamW where late-layer adaptation matters most.
|
|
46
|
+
- Partitions layers deterministically from `model.named_modules()`.
|
|
47
|
+
- Supports separate learning rates and weight decay for trunk and cap.
|
|
48
|
+
- Exposes the chosen partition through `optimizer.partition`.
|
|
49
|
+
- Rejects sparse gradients and dynamic `add_param_group()` explicitly.
|
|
50
|
+
|
|
51
|
+
## Install
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
python -m pip install .
|
|
55
|
+
```
|
|
56
|
+
|
|
57
|
+
Development install:
|
|
58
|
+
|
|
59
|
+
```bash
|
|
60
|
+
python -m pip install -e ".[dev]"
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
## Quickstart
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
import torch
|
|
67
|
+
from torch import nn
|
|
68
|
+
|
|
69
|
+
from stac_optimizer import STAC
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
model = nn.Sequential(
|
|
73
|
+
nn.Linear(128, 64),
|
|
74
|
+
nn.ReLU(),
|
|
75
|
+
nn.Linear(64, 32),
|
|
76
|
+
nn.ReLU(),
|
|
77
|
+
nn.Linear(32, 10),
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
optimizer = STAC(
|
|
81
|
+
model,
|
|
82
|
+
lr=1e-3,
|
|
83
|
+
last_n_layers=1,
|
|
84
|
+
trunk_momentum=0.9,
|
|
85
|
+
trunk_lr=8e-4,
|
|
86
|
+
cap_lr=1e-3,
|
|
87
|
+
weight_decay=1e-2,
|
|
88
|
+
error_if_nonfinite=True,
|
|
89
|
+
)
|
|
90
|
+
|
|
91
|
+
inputs = torch.randn(8, 128)
|
|
92
|
+
targets = torch.randn(8, 10)
|
|
93
|
+
|
|
94
|
+
loss = torch.nn.functional.mse_loss(model(inputs), targets)
|
|
95
|
+
loss.backward()
|
|
96
|
+
optimizer.step()
|
|
97
|
+
optimizer.zero_grad(set_to_none=True)
|
|
98
|
+
|
|
99
|
+
print("trunk:", optimizer.partition.trunk_layer_names)
|
|
100
|
+
print("cap:", optimizer.partition.cap_layer_names)
|
|
101
|
+
```
|
|
102
|
+
|
|
103
|
+
## Partition Rule
|
|
104
|
+
|
|
105
|
+
STAC walks trainable layers in module registration order and splits them into
|
|
106
|
+
two regions:
|
|
107
|
+
|
|
108
|
+
```text
|
|
109
|
+
[ earlier trainable layers ................. ][ last N trainable layers ]
|
|
110
|
+
trunk: signSGD-like cap: AdamW
|
|
111
|
+
```
|
|
112
|
+
|
|
113
|
+
- Layer discovery uses `named_parameters(recurse=False)`.
|
|
114
|
+
- Frozen parameters are skipped when counting layers.
|
|
115
|
+
- Shared parameters are assigned to the first discovered owner.
|
|
116
|
+
- Root-level parameters are exposed as `"<root>"`.
|
|
117
|
+
- `last_n_layers=0` keeps the whole model in the trunk.
|
|
118
|
+
- Oversized `last_n_layers` moves the whole model into the cap.
|
|
119
|
+
|
|
120
|
+
## Hyperparameters
|
|
121
|
+
|
|
122
|
+
| Argument | Meaning |
|
|
123
|
+
| --- | --- |
|
|
124
|
+
| `lr` | Shared base learning rate. |
|
|
125
|
+
| `trunk_lr`, `cap_lr` | Role-specific learning rates. If `trunk_lr` is omitted in hybrid mode, STAC defaults it to `0.75 * lr`. |
|
|
126
|
+
| `last_n_layers` | Number of final trainable layers that become AdamW. |
|
|
127
|
+
| `trunk_momentum` | EMA factor for the trunk before taking the sign. |
|
|
128
|
+
| `weight_decay` | Shared default decoupled weight decay. |
|
|
129
|
+
| `trunk_weight_decay`, `cap_weight_decay` | Role-specific decoupled weight decay. |
|
|
130
|
+
| `betas`, `eps`, `amsgrad` | AdamW cap hyperparameters. |
|
|
131
|
+
| `maximize` | Maximize instead of minimize. |
|
|
132
|
+
| `error_if_nonfinite` | Raise on `NaN` or `Inf` gradients. |
|
|
133
|
+
|
|
134
|
+
## Stability Notes
|
|
135
|
+
|
|
136
|
+
The defaults are intentionally conservative:
|
|
137
|
+
|
|
138
|
+
- The trunk uses momentum because sign-only methods are substantially more
|
|
139
|
+
stable when the sign is taken after smoothing. See
|
|
140
|
+
[signSGD with Majority Vote](https://arxiv.org/abs/1810.05291) and
|
|
141
|
+
[Momentum Ensures Convergence of SIGNSGD under Weaker Assumptions](https://proceedings.mlr.press/v202/sun23l.html).
|
|
142
|
+
- The cap uses AdamW-style decoupled weight decay rather than mixing decay
|
|
143
|
+
into the gradient. See
|
|
144
|
+
[Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101).
|
|
145
|
+
- Recent analysis shows sign-based methods have different optimization
|
|
146
|
+
tradeoffs from SGD and Adam depending on noise and conditioning, which is
|
|
147
|
+
why STAC exposes both `last_n_layers` and separate trunk/cap learning rates.
|
|
148
|
+
See
|
|
149
|
+
[Exact Risk Curves of SignSGD in Modern Overparameterized Linear Regression](https://proceedings.mlr.press/v267/xiao25c.html).
|
|
150
|
+
|
|
151
|
+
Practical tuning guidance:
|
|
152
|
+
|
|
153
|
+
- If training is noisy or unstable, raise `trunk_momentum` before increasing
|
|
154
|
+
the trunk learning rate.
|
|
155
|
+
- If the model underfits, move more layers into the AdamW cap with a larger
|
|
156
|
+
`last_n_layers`.
|
|
157
|
+
- If the head adapts too slowly, raise `cap_lr` without forcing the entire
|
|
158
|
+
network into AdamW.
|
|
159
|
+
|
|
160
|
+
## Benchmark Snapshot
|
|
161
|
+
|
|
162
|
+
The repository includes [`examples/toy_benchmark.py`](examples/toy_benchmark.py)
|
|
163
|
+
for a quick sanity check. A representative local run on `Python 3.13.12` and
|
|
164
|
+
`torch 2.10.0+cu126` produced:
|
|
165
|
+
|
|
166
|
+
| Optimizer | Mean final loss |
|
|
167
|
+
| --- | ---: |
|
|
168
|
+
| `STAC` default | `0.033961` |
|
|
169
|
+
| `STAC` with plain sign trunk | `0.107899` |
|
|
170
|
+
| `torch.optim.AdamW` | `0.074642` |
|
|
171
|
+
|
|
172
|
+
This is a sanity benchmark, not a universal ranking. The important signal is
|
|
173
|
+
that the default STAC trunk is meaningfully better than a plain sign trunk on a
|
|
174
|
+
real optimization loop.
|
|
175
|
+
|
|
176
|
+
## Constraints
|
|
177
|
+
|
|
178
|
+
- Sparse gradients are unsupported in both trunk and cap.
|
|
179
|
+
- `add_param_group()` is intentionally unsupported because STAC derives its
|
|
180
|
+
parameter groups from model structure.
|
|
181
|
+
- The split follows module registration order, not dynamic forward order.
|
|
182
|
+
|
|
183
|
+
## Verification
|
|
184
|
+
|
|
185
|
+
GitHub Actions automation:
|
|
186
|
+
|
|
187
|
+
- On pull requests and pushes to `main`: CPU-based tests, packaging, and built
|
|
188
|
+
wheel smoke checks.
|
|
189
|
+
- On `v*` tags: version validation, rebuild, `twine check`, PyPI publishing,
|
|
190
|
+
and GitHub Release creation.
|
|
191
|
+
|
|
192
|
+
Local CUDA verification for maintainers before a release:
|
|
193
|
+
|
|
194
|
+
```bash
|
|
195
|
+
python -m pytest -q
|
|
196
|
+
python -m build
|
|
197
|
+
python -m twine check dist/*
|
|
198
|
+
python examples/toy_benchmark.py
|
|
199
|
+
```
|
|
200
|
+
|
|
201
|
+
Most recent local CUDA run:
|
|
202
|
+
|
|
203
|
+
- `python -m pytest -q`: `17 passed in 6.45s`
|
|
204
|
+
- `python -m build` and `python -m twine check dist/*`: passed
|
|
205
|
+
- `python examples/toy_benchmark.py`:
|
|
206
|
+
`STAC` default `0.033961`, plain sign trunk `0.107899`, `AdamW` `0.074642`
|
|
207
|
+
|
|
208
|
+
## Release
|
|
209
|
+
|
|
210
|
+
This repository uses `setuptools-scm`, so release tags must match the package
|
|
211
|
+
version that the workflow computes from the tagged commit.
|
|
212
|
+
|
|
213
|
+
Typical release flow:
|
|
214
|
+
|
|
215
|
+
```bash
|
|
216
|
+
git push origin main
|
|
217
|
+
git tag v0.1.2
|
|
218
|
+
git push origin v0.1.2
|
|
219
|
+
```
|
|
220
|
+
|
|
221
|
+
The tag workflow then:
|
|
222
|
+
|
|
223
|
+
1. Verifies that `vX.Y.Z` matches the computed package version.
|
|
224
|
+
2. Builds fresh distributions and runs `twine check`.
|
|
225
|
+
3. Publishes to PyPI via GitHub Actions Trusted Publishing.
|
|
226
|
+
4. Creates the matching GitHub Release and attaches the built artifacts.
|
|
227
|
+
|
|
228
|
+
Project maintainers must register this repository and
|
|
229
|
+
`.github/workflows/workflow.yml` as a Trusted Publisher on PyPI for the publish
|
|
230
|
+
step to succeed.
|
|
231
|
+
|
|
232
|
+
See [CHANGELOG.md](CHANGELOG.md) for released versions only.
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
# stac-optimizer
|
|
2
|
+
|
|
3
|
+
STAC stands for SignSGD Trunk, AdamW Cap.
|
|
4
|
+
|
|
5
|
+
It is a PyTorch optimizer for models where you want cheap sign-based updates
|
|
6
|
+
through most of the network, but still want AdamW on the last few trainable
|
|
7
|
+
layers where optimization is often most sensitive. The default trunk is
|
|
8
|
+
`sign(momentum)` rather than plain `sign(grad)` because the momentum-smoothed
|
|
9
|
+
variant is materially more stable in both theory and practice.
|
|
10
|
+
|
|
11
|
+
| Item | Value |
|
|
12
|
+
| --- | --- |
|
|
13
|
+
| Python | `>=3.13` |
|
|
14
|
+
| PyTorch | `>=2.10` |
|
|
15
|
+
| Default split | last `1` trainable layer uses AdamW |
|
|
16
|
+
| Trunk update | sign-based update with momentum smoothing |
|
|
17
|
+
| Cap update | AdamW with decoupled weight decay |
|
|
18
|
+
|
|
19
|
+
## Why STAC
|
|
20
|
+
|
|
21
|
+
- Keeps the bulk of the model on sign-based updates.
|
|
22
|
+
- Preserves AdamW where late-layer adaptation matters most.
|
|
23
|
+
- Partitions layers deterministically from `model.named_modules()`.
|
|
24
|
+
- Supports separate learning rates and weight decay for trunk and cap.
|
|
25
|
+
- Exposes the chosen partition through `optimizer.partition`.
|
|
26
|
+
- Rejects sparse gradients and dynamic `add_param_group()` explicitly.
|
|
27
|
+
|
|
28
|
+
## Install
|
|
29
|
+
|
|
30
|
+
```bash
|
|
31
|
+
python -m pip install .
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
Development install:
|
|
35
|
+
|
|
36
|
+
```bash
|
|
37
|
+
python -m pip install -e ".[dev]"
|
|
38
|
+
```
|
|
39
|
+
|
|
40
|
+
## Quickstart
|
|
41
|
+
|
|
42
|
+
```python
|
|
43
|
+
import torch
|
|
44
|
+
from torch import nn
|
|
45
|
+
|
|
46
|
+
from stac_optimizer import STAC
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
model = nn.Sequential(
|
|
50
|
+
nn.Linear(128, 64),
|
|
51
|
+
nn.ReLU(),
|
|
52
|
+
nn.Linear(64, 32),
|
|
53
|
+
nn.ReLU(),
|
|
54
|
+
nn.Linear(32, 10),
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
optimizer = STAC(
|
|
58
|
+
model,
|
|
59
|
+
lr=1e-3,
|
|
60
|
+
last_n_layers=1,
|
|
61
|
+
trunk_momentum=0.9,
|
|
62
|
+
trunk_lr=8e-4,
|
|
63
|
+
cap_lr=1e-3,
|
|
64
|
+
weight_decay=1e-2,
|
|
65
|
+
error_if_nonfinite=True,
|
|
66
|
+
)
|
|
67
|
+
|
|
68
|
+
inputs = torch.randn(8, 128)
|
|
69
|
+
targets = torch.randn(8, 10)
|
|
70
|
+
|
|
71
|
+
loss = torch.nn.functional.mse_loss(model(inputs), targets)
|
|
72
|
+
loss.backward()
|
|
73
|
+
optimizer.step()
|
|
74
|
+
optimizer.zero_grad(set_to_none=True)
|
|
75
|
+
|
|
76
|
+
print("trunk:", optimizer.partition.trunk_layer_names)
|
|
77
|
+
print("cap:", optimizer.partition.cap_layer_names)
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
## Partition Rule
|
|
81
|
+
|
|
82
|
+
STAC walks trainable layers in module registration order and splits them into
|
|
83
|
+
two regions:
|
|
84
|
+
|
|
85
|
+
```text
|
|
86
|
+
[ earlier trainable layers ................. ][ last N trainable layers ]
|
|
87
|
+
trunk: signSGD-like cap: AdamW
|
|
88
|
+
```
|
|
89
|
+
|
|
90
|
+
- Layer discovery uses `named_parameters(recurse=False)`.
|
|
91
|
+
- Frozen parameters are skipped when counting layers.
|
|
92
|
+
- Shared parameters are assigned to the first discovered owner.
|
|
93
|
+
- Root-level parameters are exposed as `"<root>"`.
|
|
94
|
+
- `last_n_layers=0` keeps the whole model in the trunk.
|
|
95
|
+
- Oversized `last_n_layers` moves the whole model into the cap.
|
|
96
|
+
|
|
97
|
+
## Hyperparameters
|
|
98
|
+
|
|
99
|
+
| Argument | Meaning |
|
|
100
|
+
| --- | --- |
|
|
101
|
+
| `lr` | Shared base learning rate. |
|
|
102
|
+
| `trunk_lr`, `cap_lr` | Role-specific learning rates. If `trunk_lr` is omitted in hybrid mode, STAC defaults it to `0.75 * lr`. |
|
|
103
|
+
| `last_n_layers` | Number of final trainable layers that become AdamW. |
|
|
104
|
+
| `trunk_momentum` | EMA factor for the trunk before taking the sign. |
|
|
105
|
+
| `weight_decay` | Shared default decoupled weight decay. |
|
|
106
|
+
| `trunk_weight_decay`, `cap_weight_decay` | Role-specific decoupled weight decay. |
|
|
107
|
+
| `betas`, `eps`, `amsgrad` | AdamW cap hyperparameters. |
|
|
108
|
+
| `maximize` | Maximize instead of minimize. |
|
|
109
|
+
| `error_if_nonfinite` | Raise on `NaN` or `Inf` gradients. |
|
|
110
|
+
|
|
111
|
+
## Stability Notes
|
|
112
|
+
|
|
113
|
+
The defaults are intentionally conservative:
|
|
114
|
+
|
|
115
|
+
- The trunk uses momentum because sign-only methods are substantially more
|
|
116
|
+
stable when the sign is taken after smoothing. See
|
|
117
|
+
[signSGD with Majority Vote](https://arxiv.org/abs/1810.05291) and
|
|
118
|
+
[Momentum Ensures Convergence of SIGNSGD under Weaker Assumptions](https://proceedings.mlr.press/v202/sun23l.html).
|
|
119
|
+
- The cap uses AdamW-style decoupled weight decay rather than mixing decay
|
|
120
|
+
into the gradient. See
|
|
121
|
+
[Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101).
|
|
122
|
+
- Recent analysis shows sign-based methods have different optimization
|
|
123
|
+
tradeoffs from SGD and Adam depending on noise and conditioning, which is
|
|
124
|
+
why STAC exposes both `last_n_layers` and separate trunk/cap learning rates.
|
|
125
|
+
See
|
|
126
|
+
[Exact Risk Curves of SignSGD in Modern Overparameterized Linear Regression](https://proceedings.mlr.press/v267/xiao25c.html).
|
|
127
|
+
|
|
128
|
+
Practical tuning guidance:
|
|
129
|
+
|
|
130
|
+
- If training is noisy or unstable, raise `trunk_momentum` before increasing
|
|
131
|
+
the trunk learning rate.
|
|
132
|
+
- If the model underfits, move more layers into the AdamW cap with a larger
|
|
133
|
+
`last_n_layers`.
|
|
134
|
+
- If the head adapts too slowly, raise `cap_lr` without forcing the entire
|
|
135
|
+
network into AdamW.
|
|
136
|
+
|
|
137
|
+
## Benchmark Snapshot
|
|
138
|
+
|
|
139
|
+
The repository includes [`examples/toy_benchmark.py`](examples/toy_benchmark.py)
|
|
140
|
+
for a quick sanity check. A representative local run on `Python 3.13.12` and
|
|
141
|
+
`torch 2.10.0+cu126` produced:
|
|
142
|
+
|
|
143
|
+
| Optimizer | Mean final loss |
|
|
144
|
+
| --- | ---: |
|
|
145
|
+
| `STAC` default | `0.033961` |
|
|
146
|
+
| `STAC` with plain sign trunk | `0.107899` |
|
|
147
|
+
| `torch.optim.AdamW` | `0.074642` |
|
|
148
|
+
|
|
149
|
+
This is a sanity benchmark, not a universal ranking. The important signal is
|
|
150
|
+
that the default STAC trunk is meaningfully better than a plain sign trunk on a
|
|
151
|
+
real optimization loop.
|
|
152
|
+
|
|
153
|
+
## Constraints
|
|
154
|
+
|
|
155
|
+
- Sparse gradients are unsupported in both trunk and cap.
|
|
156
|
+
- `add_param_group()` is intentionally unsupported because STAC derives its
|
|
157
|
+
parameter groups from model structure.
|
|
158
|
+
- The split follows module registration order, not dynamic forward order.
|
|
159
|
+
|
|
160
|
+
## Verification
|
|
161
|
+
|
|
162
|
+
GitHub Actions automation:
|
|
163
|
+
|
|
164
|
+
- On pull requests and pushes to `main`: CPU-based tests, packaging, and built
|
|
165
|
+
wheel smoke checks.
|
|
166
|
+
- On `v*` tags: version validation, rebuild, `twine check`, PyPI publishing,
|
|
167
|
+
and GitHub Release creation.
|
|
168
|
+
|
|
169
|
+
Local CUDA verification for maintainers before a release:
|
|
170
|
+
|
|
171
|
+
```bash
|
|
172
|
+
python -m pytest -q
|
|
173
|
+
python -m build
|
|
174
|
+
python -m twine check dist/*
|
|
175
|
+
python examples/toy_benchmark.py
|
|
176
|
+
```
|
|
177
|
+
|
|
178
|
+
Most recent local CUDA run:
|
|
179
|
+
|
|
180
|
+
- `python -m pytest -q`: `17 passed in 6.45s`
|
|
181
|
+
- `python -m build` and `python -m twine check dist/*`: passed
|
|
182
|
+
- `python examples/toy_benchmark.py`:
|
|
183
|
+
`STAC` default `0.033961`, plain sign trunk `0.107899`, `AdamW` `0.074642`
|
|
184
|
+
|
|
185
|
+
## Release
|
|
186
|
+
|
|
187
|
+
This repository uses `setuptools-scm`, so release tags must match the package
|
|
188
|
+
version that the workflow computes from the tagged commit.
|
|
189
|
+
|
|
190
|
+
Typical release flow:
|
|
191
|
+
|
|
192
|
+
```bash
|
|
193
|
+
git push origin main
|
|
194
|
+
git tag v0.1.2
|
|
195
|
+
git push origin v0.1.2
|
|
196
|
+
```
|
|
197
|
+
|
|
198
|
+
The tag workflow then:
|
|
199
|
+
|
|
200
|
+
1. Verifies that `vX.Y.Z` matches the computed package version.
|
|
201
|
+
2. Builds fresh distributions and runs `twine check`.
|
|
202
|
+
3. Publishes to PyPI via GitHub Actions Trusted Publishing.
|
|
203
|
+
4. Creates the matching GitHub Release and attaches the built artifacts.
|
|
204
|
+
|
|
205
|
+
Project maintainers must register this repository and
|
|
206
|
+
`.github/workflows/workflow.yml` as a Trusted Publisher on PyPI for the publish
|
|
207
|
+
step to succeed.
|
|
208
|
+
|
|
209
|
+
See [CHANGELOG.md](CHANGELOG.md) for released versions only.
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import statistics
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch import nn
|
|
7
|
+
|
|
8
|
+
from stac_optimizer import STAC
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ToyMLP(nn.Module):
    """Small three-layer MLP used as the benchmark model.

    Two 32-wide hidden layers form the "trunk"; the final 4-way
    projection is the "head" that STAC places in the AdamW cap.
    """

    def __init__(self) -> None:
        super().__init__()
        self.trunk_0 = nn.Linear(16, 32)
        self.trunk_1 = nn.Linear(32, 32)
        self.head = nn.Linear(32, 4)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply trunk_0 -> ReLU -> trunk_1 -> ReLU -> head."""
        hidden = torch.relu(self.trunk_0(inputs))
        hidden = torch.relu(self.trunk_1(hidden))
        return self.head(hidden)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def resolve_device() -> torch.device:
    """Prefer CUDA when it is available, otherwise fall back to CPU."""
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(backend)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def make_batch(
    seed: int,
    *,
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build a deterministic noisy linear-regression batch for *seed*.

    The RNG draw order (inputs, projection matrix, noise) is fixed, so a
    given seed always produces the same batch; tensors are generated on
    the default device and moved to *device* afterwards.
    """
    torch.manual_seed(seed)
    features = torch.randn(256, 16)
    projection = torch.randn(16, 4)
    noise = 0.1 * torch.randn(256, 4)
    labels = features @ projection + noise
    return features.to(device), labels.to(device)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def run(seed: int, optimizer_kind: str, *, device: torch.device) -> float:
    """Train a fresh ToyMLP for 200 full-batch steps; return the final loss.

    The data batch depends on *seed*, while model initialization is always
    seeded with 0 so every optimizer kind starts from identical weights.
    """
    inputs, targets = make_batch(seed, device=device)
    torch.manual_seed(0)  # identical weight init across optimizer kinds
    model = ToyMLP().to(device)
    optimizer = _build_optimizer(model, optimizer_kind)

    final_loss = 0.0
    for _ in range(200):
        optimizer.zero_grad(set_to_none=True)
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        final_loss = float(loss.detach())
    return final_loss


def _build_optimizer(model: nn.Module, optimizer_kind: str) -> torch.optim.Optimizer:
    """Instantiate the optimizer named by *optimizer_kind* for *model*.

    Raises ValueError for an unrecognized kind.
    """
    if optimizer_kind == "stac-default":
        return STAC(model, lr=3e-3, last_n_layers=1, weight_decay=1e-2)
    if optimizer_kind == "stac-plain":
        # Plain sign trunk: disable momentum smoothing entirely.
        return STAC(
            model,
            lr=3e-3,
            last_n_layers=1,
            trunk_momentum=0.0,
            weight_decay=1e-2,
        )
    if optimizer_kind == "adamw":
        return torch.optim.AdamW(
            model.parameters(),
            lr=3e-3,
            weight_decay=1e-2,
        )
    raise ValueError(f"Unknown optimizer kind: {optimizer_kind}.")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def main() -> None:
    """Benchmark every optimizer variant over five seeds and print stats."""
    device = resolve_device()
    print(f"device={device.type} torch={torch.__version__}")
    for kind in ("stac-default", "stac-plain", "adamw"):
        final_losses = [run(seed, kind, device=device) for seed in range(5)]
        print(
            kind,
            f"mean={statistics.fmean(final_losses):.6f}",
            f"min={min(final_losses):.6f}",
            f"max={max(final_losses):.6f}",
        )
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
# Script entry point: execute the benchmark only when run directly.
if __name__ == "__main__":
    main()
|