stac-optimizer 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,137 @@
1
+ name: CI
2
+
3
+ on:
4
+ pull_request:
5
+ push:
6
+ branches:
7
+ - main
8
+ tags:
9
+ - "v*"
10
+
11
+ permissions:
12
+ contents: read
13
+
14
+ jobs:
15
+ test:
16
+ runs-on: ubuntu-latest
17
+
18
+ steps:
19
+ - name: Check out repository
20
+ uses: actions/checkout@v4
21
+ with:
22
+ fetch-depth: 0
23
+
24
+ - name: Set up Python
25
+ uses: actions/setup-python@v5
26
+ with:
27
+ python-version: "3.13"
28
+
29
+ - name: Install dependencies
30
+ run: |
31
+ python -m pip install --upgrade pip
32
+ python -m pip install --extra-index-url https://download.pytorch.org/whl/cpu -e ".[dev]"
33
+
34
+ - name: Run tests
35
+ run: python -m pytest
36
+
37
+ package:
38
+ needs: test
39
+ runs-on: ubuntu-latest
40
+
41
+ steps:
42
+ - name: Check out repository
43
+ uses: actions/checkout@v4
44
+ with:
45
+ fetch-depth: 0
46
+
47
+ - name: Set up Python
48
+ uses: actions/setup-python@v5
49
+ with:
50
+ python-version: "3.13"
51
+
52
+ - name: Install packaging dependencies
53
+ run: |
54
+ python -m pip install --upgrade pip
55
+ python -m pip install build twine
56
+ python -m pip install --extra-index-url https://download.pytorch.org/whl/cpu "torch>=2.10"
57
+
58
+ - name: Build distributions
59
+ run: python -m build
60
+
61
+ - name: Validate package metadata
62
+ run: python -m twine check dist/*
63
+
64
+ - name: Smoke test built wheel
65
+ run: |
66
+ python -m pip install --force-reinstall --no-deps dist/*.whl
67
+ python - <<'PY'
68
+ from torch import nn
69
+ from stac_optimizer import STAC
70
+
71
+ model = nn.Sequential(
72
+ nn.Linear(4, 4),
73
+ nn.ReLU(),
74
+ nn.Linear(4, 2),
75
+ )
76
+ optimizer = STAC(model, last_n_layers=1)
77
+ assert optimizer.partition.cap_layer_names == ("2",)
78
+ PY
79
+
80
+ release:
81
+ if: startsWith(github.ref, 'refs/tags/v')
82
+ needs: [test, package]
83
+ runs-on: ubuntu-latest
84
+ permissions:
85
+ contents: write
86
+ id-token: write
87
+
88
+ steps:
89
+ - name: Check out repository
90
+ uses: actions/checkout@v4
91
+ with:
92
+ fetch-depth: 0
93
+
94
+ - name: Set up Python
95
+ uses: actions/setup-python@v5
96
+ with:
97
+ python-version: "3.13"
98
+
99
+ - name: Install build tooling
100
+ run: |
101
+ python -m pip install --upgrade pip
102
+ python -m pip install build setuptools-scm twine
103
+
104
+ - name: Verify tag matches package version
105
+ env:
106
+ TAG_NAME: ${{ github.ref_name }}
107
+ run: |
108
+ python - <<'PY'
109
+ import os
110
+ from setuptools_scm import get_version
111
+
112
+ tag_name = os.environ["TAG_NAME"]
113
+ tag_version = tag_name.removeprefix("v")
114
+ package_version = get_version(root=".", version_scheme="guess-next-dev", local_scheme="no-local-version")
115
+
116
+ if tag_version != package_version:
117
+ raise SystemExit(
118
+ f"Tag version {tag_version!r} does not match package version {package_version!r}."
119
+ )
120
+ PY
121
+
122
+ - name: Build distributions
123
+ run: python -m build
124
+
125
+ - name: Validate package metadata
126
+ run: python -m twine check dist/*
127
+
128
+ - name: Publish distributions to PyPI
129
+ uses: pypa/gh-action-pypi-publish@release/v1
130
+ with:
131
+ packages-dir: dist/
132
+
133
+ - name: Create GitHub Release
134
+ uses: softprops/action-gh-release@v2
135
+ with:
136
+ generate_release_notes: true
137
+ files: dist/*
@@ -0,0 +1,7 @@
1
+ __pycache__/
2
+ .pytest_cache/
3
+ .venv/
4
+ build/
5
+ dist/
6
+ src/*.egg-info/
7
+ *.egg-info/
@@ -0,0 +1,30 @@
1
+ # Changelog
2
+
3
+ All notable changes to this project will be documented in this file.
4
+
5
+ ## 0.1.2 - 2026-03-18
6
+
7
+ - Added partition-aware `load_state_dict()` validation so STAC checkpoints fail
8
+ fast when loaded into a mismatched trunk/cap split.
9
+ - Fixed the AdamW cap so `amsgrad=True` now matches PyTorch's AMSGrad behavior.
10
+ - Moved optimizer step regression tests and the effectiveness benchmark to CUDA
11
+ coverage on supported machines.
12
+ - Rewrote the README as Markdown-only documentation and removed tracked SVG
13
+ assets from the repository.
14
+ - Added a GitHub Actions path for PyPI Trusted Publishing on release tags.
15
+
16
+ ## 0.1.1 - 2026-03-18
17
+
18
+ - Added a momentum-accumulating sign trunk that keeps the STAC update sign-based
19
+ while improving convergence stability.
20
+ - Added role-specific hyperparameters for trunk vs. cap learning rates and
21
+ weight decay.
22
+ - Added opt-in non-finite gradient checks and broader optimizer regression
23
+ tests.
24
+ - Tightened package metadata to Python 3.13 and added packaging validation to
25
+ CI.
26
+ - Refreshed the README with diagrams, benchmark notes, and release guidance.
27
+
28
+ ## 0.1.0 - 2026-03-18
29
+
30
+ - Initial public release of the STAC optimizer.
@@ -0,0 +1,232 @@
1
+ Metadata-Version: 2.4
2
+ Name: stac-optimizer
3
+ Version: 0.1.2
4
+ Summary: STAC optimizer with a signSGD trunk and an AdamW cap for the last N trainable layers.
5
+ Project-URL: Homepage, https://github.com/smturtle2/stac-optimizer
6
+ Project-URL: Repository, https://github.com/smturtle2/stac-optimizer
7
+ Project-URL: Issues, https://github.com/smturtle2/stac-optimizer/issues
8
+ Project-URL: Changelog, https://github.com/smturtle2/stac-optimizer/blob/main/CHANGELOG.md
9
+ Keywords: pytorch,optimizer,signsgd,adamw
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Programming Language :: Python :: 3
12
+ Classifier: Programming Language :: Python :: 3 :: Only
13
+ Classifier: Programming Language :: Python :: 3.13
14
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
15
+ Requires-Python: >=3.13
16
+ Description-Content-Type: text/markdown
17
+ Requires-Dist: torch>=2.10
18
+ Provides-Extra: dev
19
+ Requires-Dist: build>=1.2; extra == "dev"
20
+ Requires-Dist: pytest>=8.3; extra == "dev"
21
+ Requires-Dist: setuptools-scm>=8; extra == "dev"
22
+ Requires-Dist: twine>=6.1; extra == "dev"
23
+
24
+ # stac-optimizer
25
+
26
+ STAC stands for SignSGD Trunk, AdamW Cap.
27
+
28
+ It is a PyTorch optimizer for models where you want cheap sign-based updates
29
+ through most of the network, but still want AdamW on the last few trainable
30
+ layers where optimization is often most sensitive. The default trunk is
31
+ `sign(momentum)` rather than plain `sign(grad)` because the momentum-smoothed
32
+ variant is materially more stable in both theory and practice.
33
+
34
+ | Item | Value |
35
+ | --- | --- |
36
+ | Python | `>=3.13` |
37
+ | PyTorch | `>=2.10` |
38
+ | Default split | last `1` trainable layer uses AdamW |
39
+ | Trunk update | sign-based update with momentum smoothing |
40
+ | Cap update | AdamW with decoupled weight decay |
41
+
42
+ ## Why STAC
43
+
44
+ - Keeps the bulk of the model on sign-based updates.
45
+ - Preserves AdamW where late-layer adaptation matters most.
46
+ - Partitions layers deterministically from `model.named_modules()`.
47
+ - Supports separate learning rates and weight decay for trunk and cap.
48
+ - Exposes the chosen partition through `optimizer.partition`.
49
+ - Rejects sparse gradients and dynamic `add_param_group()` explicitly.
50
+
51
+ ## Install
52
+
53
+ ```bash
54
+ python -m pip install .
55
+ ```
56
+
57
+ Development install:
58
+
59
+ ```bash
60
+ python -m pip install -e ".[dev]"
61
+ ```
62
+
63
+ ## Quickstart
64
+
65
+ ```python
66
+ import torch
67
+ from torch import nn
68
+
69
+ from stac_optimizer import STAC
70
+
71
+
72
+ model = nn.Sequential(
73
+ nn.Linear(128, 64),
74
+ nn.ReLU(),
75
+ nn.Linear(64, 32),
76
+ nn.ReLU(),
77
+ nn.Linear(32, 10),
78
+ )
79
+
80
+ optimizer = STAC(
81
+ model,
82
+ lr=1e-3,
83
+ last_n_layers=1,
84
+ trunk_momentum=0.9,
85
+ trunk_lr=8e-4,
86
+ cap_lr=1e-3,
87
+ weight_decay=1e-2,
88
+ error_if_nonfinite=True,
89
+ )
90
+
91
+ inputs = torch.randn(8, 128)
92
+ targets = torch.randn(8, 10)
93
+
94
+ loss = torch.nn.functional.mse_loss(model(inputs), targets)
95
+ loss.backward()
96
+ optimizer.step()
97
+ optimizer.zero_grad(set_to_none=True)
98
+
99
+ print("trunk:", optimizer.partition.trunk_layer_names)
100
+ print("cap:", optimizer.partition.cap_layer_names)
101
+ ```
102
+
103
+ ## Partition Rule
104
+
105
+ STAC walks trainable layers in module registration order and splits them into
106
+ two regions:
107
+
108
+ ```text
109
+ [ earlier trainable layers ................. ][ last N trainable layers ]
110
+ trunk: signSGD-like cap: AdamW
111
+ ```
112
+
113
+ - Layer discovery uses `named_parameters(recurse=False)`.
114
+ - Frozen parameters are skipped when counting layers.
115
+ - Shared parameters are assigned to the first discovered owner.
116
+ - Root-level parameters are exposed as `"<root>"`.
117
+ - `last_n_layers=0` keeps the whole model in the trunk.
118
+ - Oversized `last_n_layers` moves the whole model into the cap.
119
+
120
+ ## Hyperparameters
121
+
122
+ | Argument | Meaning |
123
+ | --- | --- |
124
+ | `lr` | Shared base learning rate. |
125
+ | `trunk_lr`, `cap_lr` | Role-specific learning rates. If `trunk_lr` is omitted in hybrid mode, STAC defaults it to `0.75 * lr`. |
126
+ | `last_n_layers` | Number of final trainable layers that become AdamW. |
127
+ | `trunk_momentum` | EMA factor for the trunk before taking the sign. |
128
+ | `weight_decay` | Shared default decoupled weight decay. |
129
+ | `trunk_weight_decay`, `cap_weight_decay` | Role-specific decoupled weight decay. |
130
+ | `betas`, `eps`, `amsgrad` | AdamW cap hyperparameters. |
131
+ | `maximize` | Maximize instead of minimize. |
132
+ | `error_if_nonfinite` | Raise on `NaN` or `Inf` gradients. |
133
+
134
+ ## Stability Notes
135
+
136
+ The defaults are intentionally conservative:
137
+
138
+ - The trunk uses momentum because sign-only methods are substantially more
139
+ stable when the sign is taken after smoothing. See
140
+ [signSGD with Majority Vote](https://arxiv.org/abs/1810.05291) and
141
+ [Momentum Ensures Convergence of SIGNSGD under Weaker Assumptions](https://proceedings.mlr.press/v202/sun23l.html).
142
+ - The cap uses AdamW-style decoupled weight decay rather than mixing decay
143
+ into the gradient. See
144
+ [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101).
145
+ - Recent analysis shows sign-based methods have different optimization
146
+ tradeoffs from SGD and Adam depending on noise and conditioning, which is
147
+ why STAC exposes both `last_n_layers` and separate trunk/cap learning rates.
148
+ See
149
+ [Exact Risk Curves of SignSGD in Modern Overparameterized Linear Regression](https://proceedings.mlr.press/v267/xiao25c.html).
150
+
151
+ Practical tuning guidance:
152
+
153
+ - If training is noisy or unstable, raise `trunk_momentum` before increasing
154
+ the trunk learning rate.
155
+ - If the model underfits, move more layers into the AdamW cap with a larger
156
+ `last_n_layers`.
157
+ - If the head adapts too slowly, raise `cap_lr` without forcing the entire
158
+ network into AdamW.
159
+
160
+ ## Benchmark Snapshot
161
+
162
+ The repository includes [`examples/toy_benchmark.py`](examples/toy_benchmark.py)
163
+ for a quick sanity check. A representative local run on `Python 3.13.12` and
164
+ `torch 2.10.0+cu126` produced:
165
+
166
+ | Optimizer | Mean final loss |
167
+ | --- | ---: |
168
+ | `STAC` default | `0.033961` |
169
+ | `STAC` with plain sign trunk | `0.107899` |
170
+ | `torch.optim.AdamW` | `0.074642` |
171
+
172
+ This is a sanity benchmark, not a universal ranking. The important signal is
173
+ that the default STAC trunk is meaningfully better than a plain sign trunk on a
174
+ real optimization loop.
175
+
176
+ ## Constraints
177
+
178
+ - Sparse gradients are unsupported in both trunk and cap.
179
+ - `add_param_group()` is intentionally unsupported because STAC derives its
180
+ parameter groups from model structure.
181
+ - The split follows module registration order, not dynamic forward order.
182
+
183
+ ## Verification
184
+
185
+ GitHub Actions automation:
186
+
187
+ - On pull requests and pushes to `main`: CPU-based tests, packaging, and built
188
+ wheel smoke checks.
189
+ - On `v*` tags: version validation, rebuild, `twine check`, PyPI publishing,
190
+ and GitHub Release creation.
191
+
192
+ Local CUDA verification for maintainers before a release:
193
+
194
+ ```bash
195
+ python -m pytest -q
196
+ python -m build
197
+ python -m twine check dist/*
198
+ python examples/toy_benchmark.py
199
+ ```
200
+
201
+ Most recent local CUDA run:
202
+
203
+ - `python -m pytest -q`: `17 passed in 6.45s`
204
+ - `python -m build` and `python -m twine check dist/*`: passed
205
+ - `python examples/toy_benchmark.py`:
206
+ `STAC` default `0.033961`, plain sign trunk `0.107899`, `AdamW` `0.074642`
207
+
208
+ ## Release
209
+
210
+ This repository uses `setuptools-scm`, so release tags must match the package
211
+ version that the workflow computes from the tagged commit.
212
+
213
+ Typical release flow:
214
+
215
+ ```bash
216
+ git push origin main
217
+ git tag v0.1.2
218
+ git push origin v0.1.2
219
+ ```
220
+
221
+ The tag workflow then:
222
+
223
+ 1. Verifies that `vX.Y.Z` matches the computed package version.
224
+ 2. Builds fresh distributions and runs `twine check`.
225
+ 3. Publishes to PyPI via GitHub Actions Trusted Publishing.
226
+ 4. Creates the matching GitHub Release and attaches the built artifacts.
227
+
228
+ Project maintainers must register this repository and
229
+ `.github/workflows/workflow.yml` as a Trusted Publisher on PyPI for the publish
230
+ step to succeed.
231
+
232
+ See [CHANGELOG.md](CHANGELOG.md) for the full list of released versions.
@@ -0,0 +1,209 @@
1
+ # stac-optimizer
2
+
3
+ STAC stands for SignSGD Trunk, AdamW Cap.
4
+
5
+ It is a PyTorch optimizer for models where you want cheap sign-based updates
6
+ through most of the network, but still want AdamW on the last few trainable
7
+ layers where optimization is often most sensitive. The default trunk is
8
+ `sign(momentum)` rather than plain `sign(grad)` because the momentum-smoothed
9
+ variant is materially more stable in both theory and practice.
10
+
11
+ | Item | Value |
12
+ | --- | --- |
13
+ | Python | `>=3.13` |
14
+ | PyTorch | `>=2.10` |
15
+ | Default split | last `1` trainable layer uses AdamW |
16
+ | Trunk update | sign-based update with momentum smoothing |
17
+ | Cap update | AdamW with decoupled weight decay |
18
+
19
+ ## Why STAC
20
+
21
+ - Keeps the bulk of the model on sign-based updates.
22
+ - Preserves AdamW where late-layer adaptation matters most.
23
+ - Partitions layers deterministically from `model.named_modules()`.
24
+ - Supports separate learning rates and weight decay for trunk and cap.
25
+ - Exposes the chosen partition through `optimizer.partition`.
26
+ - Rejects sparse gradients and dynamic `add_param_group()` explicitly.
27
+
28
+ ## Install
29
+
30
+ ```bash
31
+ python -m pip install .
32
+ ```
33
+
34
+ Development install:
35
+
36
+ ```bash
37
+ python -m pip install -e ".[dev]"
38
+ ```
39
+
40
+ ## Quickstart
41
+
42
+ ```python
43
+ import torch
44
+ from torch import nn
45
+
46
+ from stac_optimizer import STAC
47
+
48
+
49
+ model = nn.Sequential(
50
+ nn.Linear(128, 64),
51
+ nn.ReLU(),
52
+ nn.Linear(64, 32),
53
+ nn.ReLU(),
54
+ nn.Linear(32, 10),
55
+ )
56
+
57
+ optimizer = STAC(
58
+ model,
59
+ lr=1e-3,
60
+ last_n_layers=1,
61
+ trunk_momentum=0.9,
62
+ trunk_lr=8e-4,
63
+ cap_lr=1e-3,
64
+ weight_decay=1e-2,
65
+ error_if_nonfinite=True,
66
+ )
67
+
68
+ inputs = torch.randn(8, 128)
69
+ targets = torch.randn(8, 10)
70
+
71
+ loss = torch.nn.functional.mse_loss(model(inputs), targets)
72
+ loss.backward()
73
+ optimizer.step()
74
+ optimizer.zero_grad(set_to_none=True)
75
+
76
+ print("trunk:", optimizer.partition.trunk_layer_names)
77
+ print("cap:", optimizer.partition.cap_layer_names)
78
+ ```
79
+
80
+ ## Partition Rule
81
+
82
+ STAC walks trainable layers in module registration order and splits them into
83
+ two regions:
84
+
85
+ ```text
86
+ [ earlier trainable layers ................. ][ last N trainable layers ]
87
+ trunk: signSGD-like cap: AdamW
88
+ ```
89
+
90
+ - Layer discovery uses `named_parameters(recurse=False)`.
91
+ - Frozen parameters are skipped when counting layers.
92
+ - Shared parameters are assigned to the first discovered owner.
93
+ - Root-level parameters are exposed as `"<root>"`.
94
+ - `last_n_layers=0` keeps the whole model in the trunk.
95
+ - Oversized `last_n_layers` moves the whole model into the cap.
96
+
97
+ ## Hyperparameters
98
+
99
+ | Argument | Meaning |
100
+ | --- | --- |
101
+ | `lr` | Shared base learning rate. |
102
+ | `trunk_lr`, `cap_lr` | Role-specific learning rates. If `trunk_lr` is omitted in hybrid mode, STAC defaults it to `0.75 * lr`. |
103
+ | `last_n_layers` | Number of final trainable layers that become AdamW. |
104
+ | `trunk_momentum` | EMA factor for the trunk before taking the sign. |
105
+ | `weight_decay` | Shared default decoupled weight decay. |
106
+ | `trunk_weight_decay`, `cap_weight_decay` | Role-specific decoupled weight decay. |
107
+ | `betas`, `eps`, `amsgrad` | AdamW cap hyperparameters. |
108
+ | `maximize` | Maximize instead of minimize. |
109
+ | `error_if_nonfinite` | Raise on `NaN` or `Inf` gradients. |
110
+
111
+ ## Stability Notes
112
+
113
+ The defaults are intentionally conservative:
114
+
115
+ - The trunk uses momentum because sign-only methods are substantially more
116
+ stable when the sign is taken after smoothing. See
117
+ [signSGD with Majority Vote](https://arxiv.org/abs/1810.05291) and
118
+ [Momentum Ensures Convergence of SIGNSGD under Weaker Assumptions](https://proceedings.mlr.press/v202/sun23l.html).
119
+ - The cap uses AdamW-style decoupled weight decay rather than mixing decay
120
+ into the gradient. See
121
+ [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101).
122
+ - Recent analysis shows sign-based methods have different optimization
123
+ tradeoffs from SGD and Adam depending on noise and conditioning, which is
124
+ why STAC exposes both `last_n_layers` and separate trunk/cap learning rates.
125
+ See
126
+ [Exact Risk Curves of SignSGD in Modern Overparameterized Linear Regression](https://proceedings.mlr.press/v267/xiao25c.html).
127
+
128
+ Practical tuning guidance:
129
+
130
+ - If training is noisy or unstable, raise `trunk_momentum` before increasing
131
+ the trunk learning rate.
132
+ - If the model underfits, move more layers into the AdamW cap with a larger
133
+ `last_n_layers`.
134
+ - If the head adapts too slowly, raise `cap_lr` without forcing the entire
135
+ network into AdamW.
136
+
137
+ ## Benchmark Snapshot
138
+
139
+ The repository includes [`examples/toy_benchmark.py`](examples/toy_benchmark.py)
140
+ for a quick sanity check. A representative local run on `Python 3.13.12` and
141
+ `torch 2.10.0+cu126` produced:
142
+
143
+ | Optimizer | Mean final loss |
144
+ | --- | ---: |
145
+ | `STAC` default | `0.033961` |
146
+ | `STAC` with plain sign trunk | `0.107899` |
147
+ | `torch.optim.AdamW` | `0.074642` |
148
+
149
+ This is a sanity benchmark, not a universal ranking. The important signal is
150
+ that the default STAC trunk is meaningfully better than a plain sign trunk on a
151
+ real optimization loop.
152
+
153
+ ## Constraints
154
+
155
+ - Sparse gradients are unsupported in both trunk and cap.
156
+ - `add_param_group()` is intentionally unsupported because STAC derives its
157
+ parameter groups from model structure.
158
+ - The split follows module registration order, not dynamic forward order.
159
+
160
+ ## Verification
161
+
162
+ GitHub Actions automation:
163
+
164
+ - On pull requests and pushes to `main`: CPU-based tests, packaging, and built
165
+ wheel smoke checks.
166
+ - On `v*` tags: version validation, rebuild, `twine check`, PyPI publishing,
167
+ and GitHub Release creation.
168
+
169
+ Local CUDA verification for maintainers before a release:
170
+
171
+ ```bash
172
+ python -m pytest -q
173
+ python -m build
174
+ python -m twine check dist/*
175
+ python examples/toy_benchmark.py
176
+ ```
177
+
178
+ Most recent local CUDA run:
179
+
180
+ - `python -m pytest -q`: `17 passed in 6.45s`
181
+ - `python -m build` and `python -m twine check dist/*`: passed
182
+ - `python examples/toy_benchmark.py`:
183
+ `STAC` default `0.033961`, plain sign trunk `0.107899`, `AdamW` `0.074642`
184
+
185
+ ## Release
186
+
187
+ This repository uses `setuptools-scm`, so release tags must match the package
188
+ version that the workflow computes from the tagged commit.
189
+
190
+ Typical release flow:
191
+
192
+ ```bash
193
+ git push origin main
194
+ git tag v0.1.2
195
+ git push origin v0.1.2
196
+ ```
197
+
198
+ The tag workflow then:
199
+
200
+ 1. Verifies that `vX.Y.Z` matches the computed package version.
201
+ 2. Builds fresh distributions and runs `twine check`.
202
+ 3. Publishes to PyPI via GitHub Actions Trusted Publishing.
203
+ 4. Creates the matching GitHub Release and attaches the built artifacts.
204
+
205
+ Project maintainers must register this repository and
206
+ `.github/workflows/workflow.yml` as a Trusted Publisher on PyPI for the publish
207
+ step to succeed.
208
+
209
+ See [CHANGELOG.md](CHANGELOG.md) for released versions only.
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ import statistics
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from stac_optimizer import STAC
9
+
10
+
11
+ class ToyMLP(nn.Module):
12
+ def __init__(self) -> None:
13
+ super().__init__()
14
+ self.trunk_0 = nn.Linear(16, 32)
15
+ self.trunk_1 = nn.Linear(32, 32)
16
+ self.head = nn.Linear(32, 4)
17
+
18
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
19
+ inputs = torch.relu(self.trunk_0(inputs))
20
+ inputs = torch.relu(self.trunk_1(inputs))
21
+ return self.head(inputs)
22
+
23
+
24
+ def resolve_device() -> torch.device:
25
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
+
28
+ def make_batch(
29
+ seed: int,
30
+ *,
31
+ device: torch.device,
32
+ ) -> tuple[torch.Tensor, torch.Tensor]:
33
+ torch.manual_seed(seed)
34
+ inputs = torch.randn(256, 16)
35
+ target_matrix = torch.randn(16, 4)
36
+ targets = inputs @ target_matrix + 0.1 * torch.randn(256, 4)
37
+ return inputs.to(device), targets.to(device)
38
+
39
+
40
+ def run(seed: int, optimizer_kind: str, *, device: torch.device) -> float:
41
+ inputs, targets = make_batch(seed, device=device)
42
+ torch.manual_seed(0)
43
+ model = ToyMLP().to(device)
44
+
45
+ if optimizer_kind == "stac-default":
46
+ optimizer = STAC(model, lr=3e-3, last_n_layers=1, weight_decay=1e-2)
47
+ elif optimizer_kind == "stac-plain":
48
+ optimizer = STAC(
49
+ model,
50
+ lr=3e-3,
51
+ last_n_layers=1,
52
+ trunk_momentum=0.0,
53
+ weight_decay=1e-2,
54
+ )
55
+ elif optimizer_kind == "adamw":
56
+ optimizer = torch.optim.AdamW(
57
+ model.parameters(),
58
+ lr=3e-3,
59
+ weight_decay=1e-2,
60
+ )
61
+ else:
62
+ raise ValueError(f"Unknown optimizer kind: {optimizer_kind}.")
63
+
64
+ for _ in range(200):
65
+ optimizer.zero_grad(set_to_none=True)
66
+ predictions = model(inputs)
67
+ loss = torch.nn.functional.mse_loss(predictions, targets)
68
+ loss.backward()
69
+ optimizer.step()
70
+
71
+ return float(loss.detach())
72
+
73
+
74
+ def main() -> None:
75
+ device = resolve_device()
76
+ print(f"device={device.type} torch={torch.__version__}")
77
+ for optimizer_kind in ("stac-default", "stac-plain", "adamw"):
78
+ losses = [run(seed, optimizer_kind, device=device) for seed in range(5)]
79
+ print(
80
+ optimizer_kind,
81
+ f"mean={statistics.fmean(losses):.6f}",
82
+ f"min={min(losses):.6f}",
83
+ f"max={max(losses):.6f}",
84
+ )
85
+
86
+
87
+ if __name__ == "__main__":
88
+ main()