lrsched 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lrsched-0.1.0/.github/ISSUE_TEMPLATE/bug_report.md +24 -0
- lrsched-0.1.0/.github/ISSUE_TEMPLATE/feature_request.md +18 -0
- lrsched-0.1.0/.github/PULL_REQUEST_TEMPLATE.md +13 -0
- lrsched-0.1.0/.github/workflows/ci.yml +22 -0
- lrsched-0.1.0/.gitignore +15 -0
- lrsched-0.1.0/CHANGELOG.md +21 -0
- lrsched-0.1.0/CLAUDE.md +47 -0
- lrsched-0.1.0/CODE_OF_CONDUCT.md +37 -0
- lrsched-0.1.0/CONTRIBUTING.md +37 -0
- lrsched-0.1.0/LICENSE +21 -0
- lrsched-0.1.0/PKG-INFO +146 -0
- lrsched-0.1.0/README.md +95 -0
- lrsched-0.1.0/SECURITY.md +19 -0
- lrsched-0.1.0/assets/logo.png +0 -0
- lrsched-0.1.0/docs/architecture.md +50 -0
- lrsched-0.1.0/docs/charter.md +33 -0
- lrsched-0.1.0/docs/logo-prompt.md +10 -0
- lrsched-0.1.0/examples/schedules.py +30 -0
- lrsched-0.1.0/pyproject.toml +61 -0
- lrsched-0.1.0/src/lrsched/__init__.py +34 -0
- lrsched-0.1.0/src/lrsched/_types.py +4 -0
- lrsched-0.1.0/src/lrsched/_validate.py +19 -0
- lrsched-0.1.0/src/lrsched/compose.py +48 -0
- lrsched-0.1.0/src/lrsched/py.typed +0 -0
- lrsched-0.1.0/src/lrsched/schedules.py +170 -0
- lrsched-0.1.0/tests/test_compose.py +30 -0
- lrsched-0.1.0/tests/test_properties.py +31 -0
- lrsched-0.1.0/tests/test_schedules.py +107 -0
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: Bug report
|
|
3
|
+
about: A wrong learning rate or unexpected behavior
|
|
4
|
+
title: ""
|
|
5
|
+
labels: bug
|
|
6
|
+
---
|
|
7
|
+
|
|
8
|
+
## Input
|
|
9
|
+
|
|
10
|
+
The schedule and parameters, for example
|
|
11
|
+
`cosine(base_lr=1e-3, min_lr=1e-5, total_steps=1000)` at step 500.
|
|
12
|
+
|
|
13
|
+
## Expected
|
|
14
|
+
|
|
15
|
+
The learning rate you expected, with a source if it is a known formula.
|
|
16
|
+
|
|
17
|
+
## Actual
|
|
18
|
+
|
|
19
|
+
The value you observed.
|
|
20
|
+
|
|
21
|
+
## Environment
|
|
22
|
+
|
|
23
|
+
- lrsched version:
|
|
24
|
+
- Python version:
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
---
|
|
2
|
+
name: Feature request
|
|
3
|
+
about: Suggest a schedule or capability for lrsched
|
|
4
|
+
title: ""
|
|
5
|
+
labels: enhancement
|
|
6
|
+
---
|
|
7
|
+
|
|
8
|
+
## What
|
|
9
|
+
|
|
10
|
+
The schedule or capability you would like.
|
|
11
|
+
|
|
12
|
+
## Why
|
|
13
|
+
|
|
14
|
+
The training or analysis workflow this would support.
|
|
15
|
+
|
|
16
|
+
## References
|
|
17
|
+
|
|
18
|
+
The paper or formula that pins down the expected behavior.
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
## Summary
|
|
2
|
+
|
|
3
|
+
What this changes and why.
|
|
4
|
+
|
|
5
|
+
## Checklist
|
|
6
|
+
|
|
7
|
+
- [ ] Tests added or updated (a bug fix starts with a failing test)
|
|
8
|
+
- [ ] Exact reference values for any new schedule, plus a shape check
|
|
9
|
+
- [ ] `uv run ruff check .` passes
|
|
10
|
+
- [ ] `uv run mypy src` passes
|
|
11
|
+
- [ ] `uv run pytest` passes
|
|
12
|
+
- [ ] No runtime dependencies added
|
|
13
|
+
- [ ] No em dash characters in docs or commit messages
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
name: CI
|
|
2
|
+
|
|
3
|
+
on:
|
|
4
|
+
push:
|
|
5
|
+
branches: [main]
|
|
6
|
+
pull_request:
|
|
7
|
+
|
|
8
|
+
jobs:
|
|
9
|
+
test:
|
|
10
|
+
runs-on: ubuntu-latest
|
|
11
|
+
strategy:
|
|
12
|
+
matrix:
|
|
13
|
+
python: ["3.10", "3.11", "3.12", "3.13"]
|
|
14
|
+
steps:
|
|
15
|
+
- uses: actions/checkout@v4
|
|
16
|
+
- uses: actions/setup-python@v5
|
|
17
|
+
with:
|
|
18
|
+
python-version: ${{ matrix.python }}
|
|
19
|
+
- run: pip install -e ".[dev]"
|
|
20
|
+
- run: ruff check .
|
|
21
|
+
- run: mypy src
|
|
22
|
+
- run: pytest -q
|
lrsched-0.1.0/.gitignore
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# Changelog
|
|
2
|
+
|
|
3
|
+
All notable changes to this project are documented here. The format follows
|
|
4
|
+
[Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and the project
|
|
5
|
+
adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
|
6
|
+
|
|
7
|
+
## [Unreleased]
|
|
8
|
+
|
|
9
|
+
### Planned
|
|
10
|
+
- Cyclical triangular schedules.
|
|
11
|
+
- A small CLI to print or sample a schedule.
|
|
12
|
+
|
|
13
|
+
## [0.1.0]
|
|
14
|
+
|
|
15
|
+
### Added
|
|
16
|
+
- Schedule type (a callable from step to learning rate).
|
|
17
|
+
- Schedules: constant, step_decay, multi_step, exponential, linear, polynomial, cosine,
|
|
18
|
+
cosine_restarts (SGDR), inverse_sqrt (Transformer), one_cycle.
|
|
19
|
+
- Composition: with_warmup, sequential, sample.
|
|
20
|
+
- Validation with clear errors. Test suite with exact reference values and Hypothesis
|
|
21
|
+
invariants.
|
lrsched-0.1.0/CLAUDE.md
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# lrsched
|
|
2
|
+
|
|
3
|
+
Pure-Python, framework-agnostic learning-rate schedules. Zero runtime dependencies.
|
|
4
|
+
|
|
5
|
+
## Commands
|
|
6
|
+
|
|
7
|
+
- Create env and install: `uv venv && uv pip install -e ".[dev]"`
|
|
8
|
+
- Test: `uv run pytest -q`
|
|
9
|
+
- Lint: `uv run ruff check .` (format with `uv run ruff format .`)
|
|
10
|
+
- Types: `uv run mypy src`
|
|
11
|
+
- Build: `uv build` (then `uv run --with twine twine check dist/*` before publishing)
|
|
12
|
+
|
|
13
|
+
## Architecture
|
|
14
|
+
|
|
15
|
+
`src/lrsched/`:
|
|
16
|
+
- `_types.py` the Schedule type (Callable[[int], float])
|
|
17
|
+
- `_validate.py` shared validation (finite floats, positive ints, step)
|
|
18
|
+
- `schedules.py` schedule factories (cosine, cosine_restarts (SGDR), one_cycle, inverse_sqrt, and so on)
|
|
19
|
+
- `compose.py` with_warmup, sequential, sample
|
|
20
|
+
- `__init__.py` public surface
|
|
21
|
+
|
|
22
|
+
See `docs/architecture.md` for the formulas.
|
|
23
|
+
|
|
24
|
+
## Conventions
|
|
25
|
+
|
|
26
|
+
- Each schedule factory returns a pure function from step to learning rate.
|
|
27
|
+
- Parameters are required keyword-only arguments; no default values.
|
|
28
|
+
- Schedules hold their final value past the end; a negative step raises.
|
|
29
|
+
- No runtime dependencies. Keep the public API small.
|
|
30
|
+
|
|
31
|
+
## Testing rules
|
|
32
|
+
|
|
33
|
+
- Exact value of each schedule at known steps.
|
|
34
|
+
- Shape checks (restart boundaries, one-cycle peak, warmup handoff).
|
|
35
|
+
- Hypothesis invariants (cosine within bounds, warmup monotone).
|
|
36
|
+
- Bug fixes start with a failing test.
|
|
37
|
+
|
|
38
|
+
## Release
|
|
39
|
+
|
|
40
|
+
- Semantic versioning; update `CHANGELOG.md`, `version` in `pyproject.toml`, and `__version__` in `src/lrsched/__init__.py` together.
|
|
41
|
+
- Gates: `uv run pytest && uv run ruff check . && uv run mypy src && uv build && uv run --with twine twine check dist/*` (twine is not a dev dependency, so it is pulled in with `--with`).
|
|
42
|
+
- Publish to PyPI, tag `vX.Y.Z`, GitHub release.
|
|
43
|
+
|
|
44
|
+
## Style
|
|
45
|
+
|
|
46
|
+
- No em dash characters in docs, comments, or commit messages.
|
|
47
|
+
- Comments explain non-obvious reasoning only.
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Code of Conduct
|
|
2
|
+
|
|
3
|
+
## Our pledge
|
|
4
|
+
|
|
5
|
+
We as members, contributors, and maintainers pledge to make participation in our
|
|
6
|
+
community a harassment-free experience for everyone, regardless of age, body
|
|
7
|
+
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
|
8
|
+
identity and expression, level of experience, education, socio-economic status,
|
|
9
|
+
nationality, personal appearance, race, religion, or sexual identity and
|
|
10
|
+
orientation.
|
|
11
|
+
|
|
12
|
+
## Our standards
|
|
13
|
+
|
|
14
|
+
Examples of behavior that contributes to a positive environment:
|
|
15
|
+
|
|
16
|
+
- Showing empathy and kindness toward other people.
|
|
17
|
+
- Being respectful of differing opinions, viewpoints, and experiences.
|
|
18
|
+
- Giving and gracefully accepting constructive feedback.
|
|
19
|
+
- Focusing on what is best for the community.
|
|
20
|
+
|
|
21
|
+
Examples of unacceptable behavior:
|
|
22
|
+
|
|
23
|
+
- Harassment, insulting or derogatory comments, and personal or political attacks.
|
|
24
|
+
- Public or private harassment.
|
|
25
|
+
- Publishing others' private information without explicit permission.
|
|
26
|
+
- Other conduct which could reasonably be considered inappropriate.
|
|
27
|
+
|
|
28
|
+
## Enforcement
|
|
29
|
+
|
|
30
|
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
|
31
|
+
reported to the maintainer at amaar2cool@gmail.com. All complaints will be
|
|
32
|
+
reviewed and investigated promptly and fairly.
|
|
33
|
+
|
|
34
|
+
## Attribution
|
|
35
|
+
|
|
36
|
+
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org),
|
|
37
|
+
version 2.1.
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
# Contributing to lrsched
|
|
2
|
+
|
|
3
|
+
Thanks for your interest. This project values correctness, a small surface area,
|
|
4
|
+
and zero runtime dependencies.
|
|
5
|
+
|
|
6
|
+
## Development
|
|
7
|
+
|
|
8
|
+
```sh
|
|
9
|
+
uv venv
|
|
10
|
+
uv pip install -e ".[dev]"
|
|
11
|
+
uv run pytest -q
|
|
12
|
+
uv run ruff check .
|
|
13
|
+
uv run mypy src
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
A standard virtual environment with `pip install -e ".[dev]"` works the same way.
|
|
17
|
+
|
|
18
|
+
## Guidelines
|
|
19
|
+
|
|
20
|
+
- No runtime dependencies.
|
|
21
|
+
- Each schedule is a pure function from step to learning rate, with required
|
|
22
|
+
keyword-only parameters and no default values.
|
|
23
|
+
- Every schedule needs an exact reference-value test at known steps, a shape check, and
|
|
24
|
+
any invariant it should satisfy.
|
|
25
|
+
- A bug fix starts with a failing test.
|
|
26
|
+
- Run `uv run ruff format .` before committing.
|
|
27
|
+
- Commit messages follow `type(scope): description`.
|
|
28
|
+
|
|
29
|
+
## Adding a schedule
|
|
30
|
+
|
|
31
|
+
State the formula in the docstring and in `docs/architecture.md`, list the schedule in the README, give exact test values at a few steps,
|
|
32
|
+
and document the behavior past the end of training.
|
|
33
|
+
|
|
34
|
+
## Reporting issues
|
|
35
|
+
|
|
36
|
+
Open an issue with the schedule and parameters, the step, the expected learning rate with
|
|
37
|
+
a source, and the value you observed.
|
lrsched-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Amaar Chughtai
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
lrsched-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: lrsched
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Framework-agnostic learning-rate schedules as pure functions, in Python with zero dependencies.
|
|
5
|
+
Project-URL: Homepage, https://github.com/amaar-mc/lrsched
|
|
6
|
+
Project-URL: Repository, https://github.com/amaar-mc/lrsched
|
|
7
|
+
Project-URL: Issues, https://github.com/amaar-mc/lrsched/issues
|
|
8
|
+
Project-URL: Changelog, https://github.com/amaar-mc/lrsched/blob/main/CHANGELOG.md
|
|
9
|
+
Author: Amaar Chughtai
|
|
10
|
+
License: MIT License
|
|
11
|
+
|
|
12
|
+
Copyright (c) 2026 Amaar Chughtai
|
|
13
|
+
|
|
14
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
15
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
16
|
+
in the Software without restriction, including without limitation the rights
|
|
17
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
18
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
19
|
+
furnished to do so, subject to the following conditions:
|
|
20
|
+
|
|
21
|
+
The above copyright notice and this permission notice shall be included in all
|
|
22
|
+
copies or substantial portions of the Software.
|
|
23
|
+
|
|
24
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
25
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
26
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
27
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
28
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
29
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
30
|
+
SOFTWARE.
|
|
31
|
+
License-File: LICENSE
|
|
32
|
+
Keywords: cosine-annealing,deep-learning,learning-rate,lr-schedule,machine-learning,one-cycle,scheduler,sgdr,training,warmup
|
|
33
|
+
Classifier: Development Status :: 4 - Beta
|
|
34
|
+
Classifier: Intended Audience :: Developers
|
|
35
|
+
Classifier: Intended Audience :: Science/Research
|
|
36
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
37
|
+
Classifier: Programming Language :: Python :: 3
|
|
38
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
39
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
40
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
41
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
42
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
43
|
+
Classifier: Typing :: Typed
|
|
44
|
+
Requires-Python: >=3.10
|
|
45
|
+
Provides-Extra: dev
|
|
46
|
+
Requires-Dist: hypothesis>=6; extra == 'dev'
|
|
47
|
+
Requires-Dist: mypy>=1.11; extra == 'dev'
|
|
48
|
+
Requires-Dist: pytest>=8; extra == 'dev'
|
|
49
|
+
Requires-Dist: ruff>=0.6; extra == 'dev'
|
|
50
|
+
Description-Content-Type: text/markdown
|
|
51
|
+
|
|
52
|
+
# lrsched
|
|
53
|
+
|
|
54
|
+
<p align="center">
|
|
55
|
+
<img src="assets/logo.png" alt="lrsched logo" width="160">
|
|
56
|
+
</p>
|
|
57
|
+
|
|
58
|
+
[](https://pypi.org/project/lrsched/)
|
|
59
|
+
[](https://github.com/amaar-mc/lrsched/actions/workflows/ci.yml)
|
|
60
|
+
[](./LICENSE)
|
|
61
|
+
|
|
62
|
+
Framework-agnostic learning-rate schedules as pure functions, in Python with zero dependencies. Each schedule maps a step to a learning rate, so it works in any training loop or framework, or none.
|
|
63
|
+
|
|
64
|
+
## Install
|
|
65
|
+
|
|
66
|
+
```sh
|
|
67
|
+
pip install lrsched
|
|
68
|
+
```
|
|
69
|
+
|
|
70
|
+
## 30-second example
|
|
71
|
+
|
|
72
|
+
```python
|
|
73
|
+
from lrsched import cosine, with_warmup, sample
|
|
74
|
+
|
|
75
|
+
schedule = with_warmup(
|
|
76
|
+
cosine(base_lr=1e-3, min_lr=1e-5, total_steps=1000),
|
|
77
|
+
warmup_steps=100,
|
|
78
|
+
start_lr=0.0,
|
|
79
|
+
)
|
|
80
|
+
|
|
81
|
+
lr = schedule(250) # learning rate at step 250
|
|
82
|
+
curve = sample(schedule, num_steps=1000) # the whole curve, for plotting or logging
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
A schedule is just `Callable[[int], float]`. Plug `schedule(step)` into your optimizer
|
|
86
|
+
however your framework expects, or use it to drive a plain training loop.
|
|
87
|
+
|
|
88
|
+
## Why this exists
|
|
89
|
+
|
|
90
|
+
Every learning-rate scheduler is tied to a framework: `torch.optim.lr_scheduler`,
|
|
91
|
+
`timm`, `transformers`, or `optax` for JAX. If you write a custom loop, use a
|
|
92
|
+
non-PyTorch stack, or just want to plot a schedule, you end up pasting a `LambdaLR`
|
|
93
|
+
snippet. `lrsched` is a small, dependency-free library where each schedule is a pure
|
|
94
|
+
function, easy to test, plot, log, and reuse anywhere.
|
|
95
|
+
|
|
96
|
+
## Comparison
|
|
97
|
+
|
|
98
|
+
| | lrsched | torch / timm | optax |
|
|
99
|
+
|---|:---:|:---:|:---:|
|
|
100
|
+
| Framework | none | PyTorch | JAX |
|
|
101
|
+
| Pure step to lr function | yes | no (optimizer-bound) | partial |
|
|
102
|
+
| Zero dependencies | yes | no | no |
|
|
103
|
+
| Composable warmup and phases | yes | partial | yes |
|
|
104
|
+
|
|
105
|
+
## Schedules
|
|
106
|
+
|
|
107
|
+
- `constant`, `step_decay`, `multi_step`, `exponential`
|
|
108
|
+
- `linear`, `polynomial`
|
|
109
|
+
- `cosine`, `cosine_restarts` (SGDR)
|
|
110
|
+
- `inverse_sqrt` (Transformer)
|
|
111
|
+
- `one_cycle`
|
|
112
|
+
|
|
113
|
+
## Composition
|
|
114
|
+
|
|
115
|
+
- `with_warmup(schedule, ...)` prepends a linear warmup.
|
|
116
|
+
- `sequential(phases)` runs schedules back to back.
|
|
117
|
+
- `sample(schedule, num_steps=...)` evaluates a schedule over a range.
|
|
118
|
+
|
|
119
|
+
Parameters are required keyword arguments, so every schedule reads explicitly at the call
|
|
120
|
+
site. Schedules hold their final value past the end rather than erroring, and a negative
|
|
121
|
+
step raises.
|
|
122
|
+
|
|
123
|
+
## Examples
|
|
124
|
+
|
|
125
|
+
```sh
|
|
126
|
+
python examples/schedules.py
|
|
127
|
+
```
|
|
128
|
+
|
|
129
|
+
## Testing
|
|
130
|
+
|
|
131
|
+
```sh
|
|
132
|
+
pip install -e ".[dev]"
|
|
133
|
+
pytest
|
|
134
|
+
```
|
|
135
|
+
|
|
136
|
+
Tests cover the exact value of each schedule at known steps, schedule-specific shapes
|
|
137
|
+
(restarts, the one-cycle peak, the warmup handoff), and invariants checked with
|
|
138
|
+
Hypothesis (cosine stays within bounds, warmup is monotone).
|
|
139
|
+
|
|
140
|
+
## Contributing
|
|
141
|
+
|
|
142
|
+
Issues and pull requests are welcome. See [CONTRIBUTING.md](./CONTRIBUTING.md).
|
|
143
|
+
|
|
144
|
+
## License
|
|
145
|
+
|
|
146
|
+
MIT. See [LICENSE](./LICENSE).
|
lrsched-0.1.0/README.md
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
# lrsched
|
|
2
|
+
|
|
3
|
+
<p align="center">
|
|
4
|
+
<img src="https://raw.githubusercontent.com/amaar-mc/lrsched/main/assets/logo.png" alt="lrsched logo" width="160">
|
|
5
|
+
</p>
|
|
6
|
+
|
|
7
|
+
[![PyPI](https://img.shields.io/pypi/v/lrsched)](https://pypi.org/project/lrsched/)
|
|
8
|
+
[![CI](https://github.com/amaar-mc/lrsched/actions/workflows/ci.yml/badge.svg)](https://github.com/amaar-mc/lrsched/actions/workflows/ci.yml)
|
|
9
|
+
[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE)
|
|
10
|
+
|
|
11
|
+
Framework-agnostic learning-rate schedules as pure functions, in Python with zero dependencies. Each schedule maps a step to a learning rate, so it works in any training loop or framework, or none.
|
|
12
|
+
|
|
13
|
+
## Install
|
|
14
|
+
|
|
15
|
+
```sh
|
|
16
|
+
pip install lrsched
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
## 30-second example
|
|
20
|
+
|
|
21
|
+
```python
|
|
22
|
+
from lrsched import cosine, with_warmup, sample
|
|
23
|
+
|
|
24
|
+
schedule = with_warmup(
|
|
25
|
+
cosine(base_lr=1e-3, min_lr=1e-5, total_steps=1000),
|
|
26
|
+
warmup_steps=100,
|
|
27
|
+
start_lr=0.0,
|
|
28
|
+
)
|
|
29
|
+
|
|
30
|
+
lr = schedule(250) # learning rate at step 250
|
|
31
|
+
curve = sample(schedule, num_steps=1100)  # all 100 warmup + 1000 cosine steps, for plotting or logging
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
A schedule is just `Callable[[int], float]`. Plug `schedule(step)` into your optimizer
|
|
35
|
+
however your framework expects, or use it to drive a plain training loop.
|
|
36
|
+
|
|
37
|
+
## Why this exists
|
|
38
|
+
|
|
39
|
+
Most learning-rate schedulers are tied to a framework: `torch.optim.lr_scheduler`,
|
|
40
|
+
`timm`, `transformers`, or `optax` for JAX. If you write a custom loop, use a
|
|
41
|
+
non-PyTorch stack, or just want to plot a schedule, you end up pasting a `LambdaLR`
|
|
42
|
+
snippet. `lrsched` is a small, dependency-free library where each schedule is a pure
|
|
43
|
+
function, easy to test, plot, log, and reuse anywhere.
|
|
44
|
+
|
|
45
|
+
## Comparison
|
|
46
|
+
|
|
47
|
+
| | lrsched | torch / timm | optax |
|
|
48
|
+
|---|:---:|:---:|:---:|
|
|
49
|
+
| Framework | none | PyTorch | JAX |
|
|
50
|
+
| Pure step to lr function | yes | no (optimizer-bound) | partial |
|
|
51
|
+
| Zero dependencies | yes | no | no |
|
|
52
|
+
| Composable warmup and phases | yes | partial | yes |
|
|
53
|
+
|
|
54
|
+
## Schedules
|
|
55
|
+
|
|
56
|
+
- `constant`, `step_decay`, `multi_step`, `exponential`
|
|
57
|
+
- `linear`, `polynomial`
|
|
58
|
+
- `cosine`, `cosine_restarts` (SGDR)
|
|
59
|
+
- `inverse_sqrt` (Transformer)
|
|
60
|
+
- `one_cycle`
|
|
61
|
+
|
|
62
|
+
## Composition
|
|
63
|
+
|
|
64
|
+
- `with_warmup(schedule, ...)` prepends a linear warmup.
|
|
65
|
+
- `sequential(phases)` runs schedules back to back.
|
|
66
|
+
- `sample(schedule, num_steps=...)` evaluates a schedule over a range.
|
|
67
|
+
|
|
68
|
+
Parameters are required keyword arguments, so every schedule reads explicitly at the call
|
|
69
|
+
site. Schedules hold their final value past the end rather than erroring, and a negative
|
|
70
|
+
step raises.
|
|
71
|
+
|
|
72
|
+
## Examples
|
|
73
|
+
|
|
74
|
+
```sh
|
|
75
|
+
python examples/schedules.py
|
|
76
|
+
```
|
|
77
|
+
|
|
78
|
+
## Testing
|
|
79
|
+
|
|
80
|
+
```sh
|
|
81
|
+
pip install -e ".[dev]"
|
|
82
|
+
pytest
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
Tests cover the exact value of each schedule at known steps, schedule-specific shapes
|
|
86
|
+
(restarts, the one-cycle peak, the warmup handoff), and invariants checked with
|
|
87
|
+
Hypothesis (cosine stays within bounds, warmup is monotone).
|
|
88
|
+
|
|
89
|
+
## Contributing
|
|
90
|
+
|
|
91
|
+
Issues and pull requests are welcome. See [CONTRIBUTING.md](./CONTRIBUTING.md).
|
|
92
|
+
|
|
93
|
+
## License
|
|
94
|
+
|
|
95
|
+
MIT. See [LICENSE](./LICENSE).
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Security Policy
|
|
2
|
+
|
|
3
|
+
## Scope
|
|
4
|
+
|
|
5
|
+
`lrsched` is a pure computation library with no runtime dependencies, no network
|
|
6
|
+
access, and no file system access. The attack surface is limited to incorrect
|
|
7
|
+
results from malformed input, which the library guards against with explicit
|
|
8
|
+
validation.
|
|
9
|
+
|
|
10
|
+
## Reporting a vulnerability
|
|
11
|
+
|
|
12
|
+
If you find a security issue, please email amaar2cool@gmail.com with details and
|
|
13
|
+
steps to reproduce. Please do not open a public issue for security reports. You
|
|
14
|
+
can expect an initial response within a few days.
|
|
15
|
+
|
|
16
|
+
## Supported versions
|
|
17
|
+
|
|
18
|
+
The latest published minor version receives fixes. Pre-1.0 releases may introduce
|
|
19
|
+
breaking changes in minor versions, as allowed by semantic versioning.
|
|
Binary file
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# Architecture
|
|
2
|
+
|
|
3
|
+
`lrsched` is a few small pure modules. The core idea is that a schedule is a function
|
|
4
|
+
`Callable[[int], float]` from step to learning rate, built by a factory that captures its
|
|
5
|
+
parameters in a closure.
|
|
6
|
+
|
|
7
|
+
## Modules
|
|
8
|
+
|
|
9
|
+
- `_types.py` defines `Schedule`.
|
|
10
|
+
- `_validate.py` holds the shared checks (finite floats, positive integers, non-negative
|
|
11
|
+
step) so every factory reports bad input the same way.
|
|
12
|
+
- `schedules.py` holds the factories. Each validates its parameters once, then returns a
|
|
13
|
+
closure that validates the step and computes the rate.
|
|
14
|
+
- `compose.py` holds `with_warmup`, `sequential`, and `sample`, which operate on schedules
|
|
15
|
+
rather than parameters.
|
|
16
|
+
|
|
17
|
+
## Formulas
|
|
18
|
+
|
|
19
|
+
With base learning rate b, minimum m, final value `end`, total steps T, and step s (clamped to T wherever a formula uses `min(s, T)`):
|
|
20
|
+
|
|
21
|
+
- step_decay: `b * gamma ** (s // step_size)`.
|
|
22
|
+
- multi_step: `b * gamma ** (number of milestones <= s)`.
|
|
23
|
+
- exponential: `b * gamma ** s`.
|
|
24
|
+
- linear: `b + (end - b) * (min(s, T) / T)`.
|
|
25
|
+
- polynomial: `end + (b - end) * (1 - min(s, T) / T) ** power`.
|
|
26
|
+
- cosine: `m + 0.5 * (b - m) * (1 + cos(pi * min(s, T) / T))`.
|
|
27
|
+
- cosine_restarts: the cosine formula within the current cycle, where cycle length starts
|
|
28
|
+
at `period` and is multiplied by `t_mult` after each restart. The current cycle is found
|
|
29
|
+
by stepping through cycle lengths until the step falls inside one.
|
|
30
|
+
- inverse_sqrt: `b * s / warmup` during warmup, then `b * sqrt(warmup / s)`.
|
|
31
|
+
- one_cycle: a cosine ramp from `max_lr / initial_div` up to `max_lr` over the first
|
|
32
|
+
`pct_start` of training, then a cosine anneal down to `max_lr / final_div`.
|
|
33
|
+
|
|
34
|
+
## Composition
|
|
35
|
+
|
|
36
|
+
`with_warmup` reads the wrapped schedule's value at step 0 as the warmup target, ramps
|
|
37
|
+
linearly from `start_lr` to that target over the warmup, then runs the wrapped schedule
|
|
38
|
+
with the step shifted so it begins at its own step 0 when warmup ends. `sequential` maps a
|
|
39
|
+
global step into the active phase and passes a phase-local step to that phase's schedule.
|
|
40
|
+
|
|
41
|
+
## Edges
|
|
42
|
+
|
|
43
|
+
Finite schedules clamp the step to their total, so they hold their final value rather than
|
|
44
|
+
extrapolating past the end. A negative step raises, since it is almost always a bug.
|
|
45
|
+
|
|
46
|
+
## Why pure Python
|
|
47
|
+
|
|
48
|
+
The formulas are small and exact. Implementing them with no dependencies means the same
|
|
49
|
+
schedule runs under PyTorch, JAX, NumPy, or a plain loop, and tests are exact rather than
|
|
50
|
+
approximate.
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Charter
|
|
2
|
+
|
|
3
|
+
## Purpose
|
|
4
|
+
|
|
5
|
+
Provide correct, dependency-free learning-rate schedules as pure functions, so that any
|
|
6
|
+
training loop or framework, or none, can use the same schedule and so that schedules are
|
|
7
|
+
easy to test, plot, and reason about.
|
|
8
|
+
|
|
9
|
+
## Scope
|
|
10
|
+
|
|
11
|
+
- The common schedules: constant, step, multi-step, exponential, linear, polynomial,
|
|
12
|
+
cosine, cosine with warm restarts, inverse square root, and one-cycle.
|
|
13
|
+
- Composition: warmup, phase chaining, and sampling.
|
|
14
|
+
|
|
15
|
+
## Non-goals
|
|
16
|
+
|
|
17
|
+
- An optimizer or training framework. Schedules return a learning rate; the caller
|
|
18
|
+
applies it.
|
|
19
|
+
- Framework coupling. There is no dependency on PyTorch, JAX, or any array library.
|
|
20
|
+
- Stateful schedulers. Schedules are pure functions of the step.
|
|
21
|
+
|
|
22
|
+
## Principles
|
|
23
|
+
|
|
24
|
+
- Correctness first. Each schedule is tested against its exact formula at known steps.
|
|
25
|
+
- Small, stable public API. Pure functions, explicit parameters.
|
|
26
|
+
- Zero runtime dependencies.
|
|
27
|
+
- Predictable edges. Schedules hold their final value past the end; a negative step
|
|
28
|
+
raises.
|
|
29
|
+
|
|
30
|
+
## Audience
|
|
31
|
+
|
|
32
|
+
People writing custom training loops, using non-PyTorch stacks, driving schedules from
|
|
33
|
+
config, or teaching how learning-rate schedules behave.
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# Logo prompt
|
|
2
|
+
|
|
3
|
+
The logo in `assets/logo.png` was generated with OpenAI `gpt-image-1` (1024x1024,
|
|
4
|
+
medium quality) from this prompt:
|
|
5
|
+
|
|
6
|
+
> Minimal geometric logo for a learning-rate schedule library. A single smooth curve that
|
|
7
|
+
> rises then falls, like a warmup followed by cosine decay, drawn over faint axes. Flat
|
|
8
|
+
> vector, premium research-lab aesthetic. White background, black curve, one pale blue
|
|
9
|
+
> accent dot at the peak. Subtle thin grid. Centered, generous negative space, no text, no
|
|
10
|
+
> letters, no robots, no neon, crisp and high-signal.
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
"""Sample a few schedules and print them as small ASCII sparklines.
|
|
2
|
+
|
|
3
|
+
Run with: python examples/schedules.py
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
from lrsched import cosine, one_cycle, sample, with_warmup
|
|
7
|
+
|
|
8
|
+
BARS = " .:-=+*#%@"


def sparkline(values: list[float]) -> str:
    """Render ``values`` as a one-line ASCII sparkline, one character per value.

    Each value is scaled linearly between the smallest and largest value onto the
    characters of ``BARS`` (lowest first). A flat series renders entirely as the
    lowest character, and an empty series renders as an empty string.
    """
    if not values:
        # min()/max() would raise on an empty list; an empty curve is an empty line.
        return ""
    lo = min(values)
    hi = max(values)
    # A flat series has zero span; fall back to 1.0 so every value maps to the lowest bar.
    span = hi - lo or 1.0
    top = len(BARS) - 1
    # Defensive clamp so the computed index can never run past the last bar.
    return "".join(BARS[min(top, int((v - lo) / span * top))] for v in values)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
total = 60
named = {
    "cosine": cosine(base_lr=1.0, min_lr=0.0, total_steps=total),
    # with_warmup runs the wrapped schedule after the warmup, so the cosine gets the
    # remaining steps to keep the combined curve at ``total`` steps.
    "warmup+cosine": with_warmup(
        cosine(base_lr=1.0, min_lr=0.0, total_steps=total - 10), warmup_steps=10, start_lr=0.0
    ),
    "one_cycle": one_cycle(
        max_lr=1.0, total_steps=total, pct_start=0.3, initial_div=25.0, final_div=1000.0
    ),
}


def main() -> None:
    """Print one right-aligned label and sparkline per schedule in ``named``."""
    for name, schedule in named.items():
        print(f"{name:>14} {sparkline(sample(schedule, num_steps=total))}")


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "lrsched"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Framework-agnostic learning-rate schedules as pure functions, in Python with zero dependencies."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
|
+
license = { file = "LICENSE" }
|
|
12
|
+
authors = [{ name = "Amaar Chughtai" }]
|
|
13
|
+
keywords = [
|
|
14
|
+
"learning-rate",
|
|
15
|
+
"scheduler",
|
|
16
|
+
"lr-schedule",
|
|
17
|
+
"cosine-annealing",
|
|
18
|
+
"warmup",
|
|
19
|
+
"one-cycle",
|
|
20
|
+
"sgdr",
|
|
21
|
+
"machine-learning",
|
|
22
|
+
"deep-learning",
|
|
23
|
+
"training",
|
|
24
|
+
]
|
|
25
|
+
classifiers = [
|
|
26
|
+
"Development Status :: 4 - Beta",
|
|
27
|
+
"Intended Audience :: Science/Research",
|
|
28
|
+
"Intended Audience :: Developers",
|
|
29
|
+
"License :: OSI Approved :: MIT License",
|
|
30
|
+
"Programming Language :: Python :: 3",
|
|
31
|
+
"Programming Language :: Python :: 3.10",
|
|
32
|
+
"Programming Language :: Python :: 3.11",
|
|
33
|
+
"Programming Language :: Python :: 3.12",
|
|
34
|
+
"Programming Language :: Python :: 3.13",
|
|
35
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
36
|
+
"Typing :: Typed",
|
|
37
|
+
]
|
|
38
|
+
dependencies = []
|
|
39
|
+
|
|
40
|
+
[project.optional-dependencies]
|
|
41
|
+
dev = ["pytest>=8", "ruff>=0.6", "mypy>=1.11", "hypothesis>=6"]
|
|
42
|
+
|
|
43
|
+
[project.urls]
|
|
44
|
+
Homepage = "https://github.com/amaar-mc/lrsched"
|
|
45
|
+
Repository = "https://github.com/amaar-mc/lrsched"
|
|
46
|
+
Issues = "https://github.com/amaar-mc/lrsched/issues"
|
|
47
|
+
Changelog = "https://github.com/amaar-mc/lrsched/blob/main/CHANGELOG.md"
|
|
48
|
+
|
|
49
|
+
[tool.hatch.build.targets.wheel]
|
|
50
|
+
packages = ["src/lrsched"]
|
|
51
|
+
|
|
52
|
+
[tool.ruff]
|
|
53
|
+
line-length = 100
|
|
54
|
+
src = ["src", "tests"]
|
|
55
|
+
|
|
56
|
+
[tool.ruff.lint]
|
|
57
|
+
select = ["E", "F", "I", "UP", "B", "SIM", "RUF"]
|
|
58
|
+
|
|
59
|
+
[tool.mypy]
|
|
60
|
+
strict = true
|
|
61
|
+
files = ["src"]
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Framework-agnostic learning-rate schedules as pure functions, with zero dependencies."""
|
|
2
|
+
|
|
3
|
+
from ._types import Schedule
|
|
4
|
+
from .compose import sample, sequential, with_warmup
|
|
5
|
+
from .schedules import (
|
|
6
|
+
constant,
|
|
7
|
+
cosine,
|
|
8
|
+
cosine_restarts,
|
|
9
|
+
exponential,
|
|
10
|
+
inverse_sqrt,
|
|
11
|
+
linear,
|
|
12
|
+
multi_step,
|
|
13
|
+
one_cycle,
|
|
14
|
+
polynomial,
|
|
15
|
+
step_decay,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Public API, in ASCII-sorted order: the shared ``Schedule`` type alias first,
# then every schedule factory and composition helper exported by the package.
__all__ = [
    "Schedule",
    # schedule factories and composition helpers
    "constant",
    "cosine",
    "cosine_restarts",
    "exponential",
    "inverse_sqrt",
    "linear",
    "multi_step",
    "one_cycle",
    "polynomial",
    "sample",
    "sequential",
    "step_decay",
    "with_warmup",
]
# Package version string.
__version__ = "0.1.0"
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from math import isfinite
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def check_finite(name: str, value: float) -> float:
    """Return ``value`` unchanged, or raise ValueError if it is NaN or infinite.

    ``name`` is the parameter name reported in the error message.
    """
    if isfinite(value):
        return value
    raise ValueError(f"{name} must be a finite number, received {value!r}")
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def check_positive_int(name: str, value: int) -> int:
    """Return ``value`` if it is a strictly positive integer, else raise ValueError.

    ``name`` is the parameter name reported in the error message. ``bool`` is
    rejected even though it subclasses ``int``: ``True`` silently standing in
    for a count of 1 is almost always a caller mistake.
    """
    # isinstance(True, int) is True, so bool must be excluded explicitly.
    if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
        raise ValueError(f"{name} must be a positive integer, received {value!r}")
    return value
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def check_step(step: int) -> int:
    """Return ``step`` if it is a non-negative integer, else raise ValueError.

    ``bool`` is rejected even though it subclasses ``int``, so ``False``/``True``
    cannot masquerade as steps 0/1.
    """
    # isinstance(False, int) is True, so bool must be excluded explicitly.
    if isinstance(step, bool) or not isinstance(step, int) or step < 0:
        raise ValueError(f"step must be a non-negative integer, received {step!r}")
    return step
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
|
|
3
|
+
from ._types import Schedule
|
|
4
|
+
from ._validate import check_finite, check_positive_int, check_step
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def with_warmup(schedule: Schedule, *, warmup_steps: int, start_lr: float) -> Schedule:
    """Wrap ``schedule`` with a linear warmup phase lasting ``warmup_steps`` steps.

    During warmup the rate rises linearly from ``start_lr`` toward ``schedule(0)``,
    which is evaluated once, here. From step ``warmup_steps`` onward the wrapped
    schedule takes over and sees its step counted from the end of warmup, so
    ``wrapped(warmup_steps) == schedule(0)``.

    Raises ValueError for a non-positive ``warmup_steps`` or non-finite
    ``start_lr``, and (when the result is called) for an invalid step.
    """
    check_positive_int("warmup_steps", warmup_steps)
    check_finite("start_lr", start_lr)
    peak = schedule(0)  # the value the ramp climbs toward
    rise = peak - start_lr

    def wrapped(step: int) -> float:
        check_step(step)
        if step >= warmup_steps:
            # Hand off to the wrapped schedule, re-based to start at 0.
            return schedule(step - warmup_steps)
        return start_lr + rise * (step / warmup_steps)

    return wrapped
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def sequential(phases: Sequence[tuple[int, Schedule]]) -> Schedule:
    """Run schedules back to back.

    Each phase is ``(duration, schedule)``; within a phase the schedule sees a step
    counted from that phase's start. Past the final phase, the last phase's final
    value (its schedule evaluated at ``duration - 1``) is held.

    ``phases`` is copied when ``sequential`` is called, so later mutation of the
    caller's container cannot change, or invalidate, the returned schedule. Any
    iterable of pairs is accepted.

    Raises ValueError if there are no phases or any duration is not a positive
    integer, and (when the result is called) for an invalid step.
    """
    # Snapshot once so the validation below keeps holding for every later call.
    frozen = tuple(phases)
    if not frozen:
        raise ValueError("sequential requires at least one phase")
    for duration, _ in frozen:
        check_positive_int("phase duration", duration)
    last_duration, last_schedule = frozen[-1]

    def wrapped(step: int) -> float:
        check_step(step)
        offset = 0
        for duration, schedule in frozen:
            if step < offset + duration:
                return schedule(step - offset)
            offset += duration
        # Past the end of the final phase: hold its last in-range value.
        return last_schedule(last_duration - 1)

    return wrapped
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def sample(schedule: Schedule, *, num_steps: int) -> list[float]:
    """Return ``[schedule(0), ..., schedule(num_steps - 1)]``, handy for plots and tests.

    Raises ValueError unless ``num_steps`` is a positive integer.
    """
    check_positive_int("num_steps", num_steps)
    return list(map(schedule, range(num_steps)))
|
|
File without changes
|
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from math import cos, pi
|
|
3
|
+
|
|
4
|
+
from ._types import Schedule
|
|
5
|
+
from ._validate import check_finite, check_positive_int, check_step
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def constant(*, lr: float) -> Schedule:
    """Return a schedule that yields ``lr`` at every valid step.

    Raises ValueError if ``lr`` is not finite, and (when the result is called)
    for an invalid step.
    """
    check_finite("lr", lr)

    def schedule(step: int) -> float:
        check_step(step)  # validated even though the value ignores it
        return lr

    return schedule
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def step_decay(*, base_lr: float, gamma: float, step_size: int) -> Schedule:
    """Piecewise-constant decay: ``base_lr * gamma ** (step // step_size)``.

    The rate is multiplied by ``gamma`` once every ``step_size`` steps.
    Raises ValueError for a non-finite ``base_lr``/``gamma`` or a non-positive
    ``step_size``, and (when the result is called) for an invalid step.
    """
    check_finite("base_lr", base_lr)
    check_finite("gamma", gamma)
    check_positive_int("step_size", step_size)

    def schedule(step: int) -> float:
        check_step(step)
        decays = step // step_size  # completed step_size-long intervals
        return base_lr * gamma**decays

    return schedule
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def multi_step(*, base_lr: float, milestones: Sequence[int], gamma: float) -> Schedule:
    """Multiply ``base_lr`` by ``gamma`` once for every milestone already reached.

    A milestone ``m`` counts as reached when ``step >= m``. Milestones may be given
    in any order, and a repeated milestone applies ``gamma`` once per occurrence.
    Raises ValueError for a non-finite ``base_lr``/``gamma`` or a milestone that is
    not a positive integer, and (when the result is called) for an invalid step.
    """
    check_finite("base_lr", base_lr)
    check_finite("gamma", gamma)
    boundaries = sorted(milestones)  # private sorted copy of the caller's milestones
    for boundary in boundaries:
        check_positive_int("milestone", boundary)

    def schedule(step: int) -> float:
        check_step(step)
        passed = len([b for b in boundaries if b <= step])
        return base_lr * gamma**passed

    return schedule
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def exponential(*, base_lr: float, gamma: float) -> Schedule:
    """Geometric decay by ``gamma`` every step: ``base_lr * gamma ** step``.

    Raises ValueError for a non-finite ``base_lr``/``gamma``, and (when the result
    is called) for an invalid step.
    """
    check_finite("base_lr", base_lr)
    check_finite("gamma", gamma)

    def schedule(step: int) -> float:
        check_step(step)
        factor = gamma**step
        return base_lr * factor

    return schedule
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def polynomial(*, base_lr: float, end_lr: float, total_steps: int, power: float) -> Schedule:
    """Polynomial decay from base_lr to end_lr over total_steps, then hold end_lr.

    ``lr(step) = end_lr + (base_lr - end_lr) * (1 - t / total_steps) ** power``
    with ``t = min(step, total_steps)``; ``power = 1`` gives linear decay.

    Raises ValueError for non-finite rates, a non-positive ``total_steps``, or a
    ``power`` that is not a finite positive number, and (when the result is
    called) for an invalid step.
    """
    check_finite("base_lr", base_lr)
    check_finite("end_lr", end_lr)
    check_positive_int("total_steps", total_steps)
    check_finite("power", power)
    # power <= 0 breaks the contract: power == 0 never leaves base_lr (0.0 ** 0 == 1),
    # and power < 0 grows the rate and divides by zero once step >= total_steps.
    if power <= 0.0:
        raise ValueError(f"power must be positive, received {power!r}")

    def schedule(step: int) -> float:
        check_step(step)
        t = min(step, total_steps)  # clamp so the rate holds at end_lr afterwards
        remaining = 1.0 - t / total_steps
        return float(end_lr + (base_lr - end_lr) * remaining**power)

    return schedule
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def linear(*, base_lr: float, end_lr: float, total_steps: int) -> Schedule:
    """Interpolate linearly from ``base_lr`` (step 0) to ``end_lr`` (``total_steps``).

    Past ``total_steps`` the rate holds at ``end_lr``. Raises ValueError for
    non-finite rates or a non-positive ``total_steps``, and (when the result is
    called) for an invalid step.
    """
    check_finite("base_lr", base_lr)
    check_finite("end_lr", end_lr)
    check_positive_int("total_steps", total_steps)
    delta = end_lr - base_lr

    def schedule(step: int) -> float:
        check_step(step)
        progress = min(step, total_steps) / total_steps  # in [0, 1]
        return base_lr + delta * progress

    return schedule
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def cosine(*, base_lr: float, min_lr: float, total_steps: int) -> Schedule:
    """Half-cosine annealing from ``base_lr`` at step 0 to ``min_lr`` at ``total_steps``.

    Past ``total_steps`` the rate holds at ``min_lr``. Raises ValueError for
    non-finite rates or a non-positive ``total_steps``, and (when the result is
    called) for an invalid step.
    """
    check_finite("base_lr", base_lr)
    check_finite("min_lr", min_lr)
    check_positive_int("total_steps", total_steps)
    half_range = 0.5 * (base_lr - min_lr)

    def schedule(step: int) -> float:
        check_step(step)
        clamped = min(step, total_steps)
        return min_lr + half_range * (1.0 + cos(pi * clamped / total_steps))

    return schedule
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def cosine_restarts(*, base_lr: float, min_lr: float, period: int, t_mult: int) -> Schedule:
    """Cosine annealing with warm restarts (SGDR).

    Each cycle anneals from ``base_lr`` toward ``min_lr`` and then jumps back to
    ``base_lr``. The first cycle lasts ``period`` steps, and every later cycle is
    ``t_mult`` times longer than the previous one (``t_mult = 1`` keeps it periodic).

    Raises ValueError for non-finite rates or a ``period``/``t_mult`` that is not a
    positive integer, and (when the result is called) for an invalid step.
    """
    check_finite("base_lr", base_lr)
    check_finite("min_lr", min_lr)
    check_positive_int("period", period)
    check_positive_int("t_mult", t_mult)

    def schedule(step: int) -> float:
        check_step(step)
        if t_mult == 1:
            # Fixed-length cycles: locate the position in O(1) instead of walking
            # step // period cycles, which made every call linear in step.
            t, current = step % period, period
        else:
            # Cycle lengths grow geometrically, so this runs O(log step) times.
            start, current = 0, period
            while step >= start + current:
                start += current
                current *= t_mult
            t = step - start
        return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + cos(pi * t / current))

    return schedule
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def inverse_sqrt(*, base_lr: float, warmup_steps: int) -> Schedule:
    """Transformer schedule: linear warmup, then inverse-square-root decay.

    The rate rises linearly from 0 at step 0 to ``base_lr`` at ``warmup_steps`` and
    afterwards equals ``base_lr * sqrt(warmup_steps / step)``. Raises ValueError
    for a non-finite ``base_lr`` or non-positive ``warmup_steps``, and (when the
    result is called) for an invalid step.
    """
    check_finite("base_lr", base_lr)
    check_positive_int("warmup_steps", warmup_steps)

    def schedule(step: int) -> float:
        check_step(step)
        if step >= warmup_steps:
            # step >= warmup_steps >= 1 here, so the division is safe.
            return float(base_lr * (warmup_steps / step) ** 0.5)
        return base_lr * step / warmup_steps

    return schedule
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def one_cycle(
    *, max_lr: float, total_steps: int, pct_start: float, initial_div: float, final_div: float
) -> Schedule:
    """One-cycle policy built from two half-cosines.

    Starts at ``max_lr / initial_div`` and rises to ``max_lr`` at the peak step,
    which is ``round(pct_start * total_steps)`` clamped to ``[1, total_steps - 1]``.
    It then falls to ``max_lr / final_div`` at ``total_steps`` and holds there.

    Raises ValueError for non-finite arguments, a non-positive ``total_steps``,
    ``pct_start`` outside the open interval (0, 1), non-positive divisors,
    ``total_steps < 2``, and (when the result is called) an invalid step.
    """
    check_finite("max_lr", max_lr)
    check_positive_int("total_steps", total_steps)
    check_finite("pct_start", pct_start)
    check_finite("initial_div", initial_div)
    check_finite("final_div", final_div)
    if pct_start <= 0.0 or pct_start >= 1.0:
        raise ValueError(f"pct_start must be between 0 and 1, received {pct_start!r}")
    if min(initial_div, final_div) <= 0.0:
        raise ValueError("initial_div and final_div must be positive")
    if total_steps < 2:
        raise ValueError("total_steps must be at least 2 for one_cycle")

    # Keep the peak off both ends: a peak at 0 would divide by zero in the ramp,
    # and a peak at total_steps would divide by zero in the descent.
    peak_step = max(1, round(pct_start * total_steps))
    peak_step = min(peak_step, total_steps - 1)
    floor_start = max_lr / initial_div
    floor_end = max_lr / final_div
    descent_len = total_steps - peak_step

    def schedule(step: int) -> float:
        check_step(step)
        if step > peak_step:
            progress = (min(step, total_steps) - peak_step) / descent_len
            return floor_end + 0.5 * (max_lr - floor_end) * (1.0 + cos(pi * progress))
        progress = step / peak_step
        return floor_start + 0.5 * (max_lr - floor_start) * (1.0 - cos(pi * progress))

    return schedule
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from lrsched import constant, cosine, sample, sequential, with_warmup
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def test_with_warmup() -> None:
    """Warmup ramps linearly to the main schedule's start, then hands off to it."""
    main = cosine(base_lr=1.0, min_lr=0.0, total_steps=10)
    s = with_warmup(main, warmup_steps=5, start_lr=0.0)
    # step 2: linear ramp 0 -> 1 over 5 steps; step 5: hands off to cosine at its step 0
    for step, expected in [(0, 0.0), (2, 0.4), (5, 1.0)]:
        assert s(step) == pytest.approx(expected)
    assert s(10) == pytest.approx(main(5))
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def test_sequential() -> None:
    """Phases run back to back, and the last phase's final value is held afterwards."""
    s = sequential([(5, constant(lr=0.1)), (5, constant(lr=0.2))])
    for step in (0, 4):
        assert s(step) == pytest.approx(0.1)  # first phase
    for step in (5, 9):
        assert s(step) == pytest.approx(0.2)  # second phase
    assert s(50) == pytest.approx(0.2)  # held at the last phase past the end
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_sequential_requires_a_phase() -> None:
    """An empty phase list is rejected when the schedule is built."""
    with pytest.raises(ValueError):
        sequential(phases=[])
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def test_sample() -> None:
    """sample evaluates steps 0..num_steps-1 in order."""
    flat = sample(constant(lr=0.5), num_steps=3)
    assert flat == [0.5, 0.5, 0.5]
    curve = sample(cosine(base_lr=1.0, min_lr=0.0, total_steps=100), num_steps=100)
    assert len(curve) == 100
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from itertools import pairwise
|
|
2
|
+
|
|
3
|
+
from hypothesis import given
|
|
4
|
+
from hypothesis import strategies as st
|
|
5
|
+
|
|
6
|
+
from lrsched import constant, cosine, sample, with_warmup
|
|
7
|
+
|
|
8
|
+
# Shared Hypothesis strategies for the property tests in this module.
steps = st.integers(min_value=1, max_value=1_000)  # positive step counts
rates = st.floats(  # finite, non-negative learning rates
    min_value=0.0,
    max_value=10.0,
    allow_nan=False,
    allow_infinity=False,
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@given(rates, rates, steps, st.integers(min_value=0, max_value=2000))
def test_cosine_stays_within_bounds(base: float, gap: float, total: int, step: int) -> None:
    """For min_lr <= base_lr, every cosine value lies in [min_lr, base_lr]."""
    lo, hi = base, base + gap
    value = cosine(base_lr=hi, min_lr=lo, total_steps=total)(step)
    tolerance = 1e-9
    assert lo - tolerance <= value <= hi + tolerance
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@given(steps, rates)
def test_warmup_is_monotone_nondecreasing(warmup: int, target: float) -> None:
    """Ramping from 0 to a non-negative target never decreases, up to the hand-off."""
    s = with_warmup(
        cosine(base_lr=target, min_lr=0.0, total_steps=100),
        warmup_steps=warmup,
        start_lr=0.0,
    )
    ramp = sample(s, num_steps=warmup + 1)
    for earlier, later in pairwise(ramp):
        assert later >= earlier - 1e-9
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@given(st.integers(min_value=1, max_value=500))
def test_sample_length(num_steps: int) -> None:
    """sample returns exactly one value per requested step."""
    values = sample(constant(lr=0.1), num_steps=num_steps)
    assert len(values) == num_steps
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
|
|
3
|
+
from lrsched import (
|
|
4
|
+
constant,
|
|
5
|
+
cosine,
|
|
6
|
+
cosine_restarts,
|
|
7
|
+
exponential,
|
|
8
|
+
inverse_sqrt,
|
|
9
|
+
linear,
|
|
10
|
+
multi_step,
|
|
11
|
+
one_cycle,
|
|
12
|
+
polynomial,
|
|
13
|
+
step_decay,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def test_constant() -> None:
    """A constant schedule returns its rate exactly, at any step."""
    s = constant(lr=0.1)
    for step in (0, 1000):
        assert s(step) == 0.1
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_exponential() -> None:
    """Each step multiplies the rate by gamma."""
    s = exponential(base_lr=1.0, gamma=0.9)
    for step, expected in [(0, 1.0), (1, 0.9), (2, 0.81)]:
        assert s(step) == pytest.approx(expected)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_step_decay() -> None:
    """The rate halves at each multiple of step_size and is flat in between."""
    s = step_decay(base_lr=1.0, gamma=0.5, step_size=10)
    for step, expected in [(0, 1.0), (9, 1.0), (10, 0.5), (20, 0.25)]:
        assert s(step) == pytest.approx(expected)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def test_multi_step() -> None:
    """Unordered milestones each apply gamma once they are reached."""
    s = multi_step(base_lr=1.0, milestones=[20, 10], gamma=0.1)
    for step, expected in [(9, 1.0), (10, 0.1), (19, 0.1), (20, 0.01)]:
        assert s(step) == pytest.approx(expected)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def test_linear() -> None:
    """Linear interpolation to end_lr, held past total_steps."""
    s = linear(base_lr=1.0, end_lr=0.0, total_steps=10)
    for step, expected in [(0, 1.0), (5, 0.5), (10, 0.0)]:
        assert s(step) == pytest.approx(expected)
    assert s(50) == pytest.approx(0.0)  # held past the end
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def test_polynomial() -> None:
    """Quadratic decay: halfway through, the rate is (1 - 0.5) ** 2 of the range."""
    s = polynomial(base_lr=1.0, end_lr=0.0, total_steps=10, power=2.0)
    for step, expected in [(0, 1.0), (5, 0.25), (10, 0.0)]:
        assert s(step) == pytest.approx(expected)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def test_cosine() -> None:
    """Cosine annealing hits the midpoint halfway and holds min_lr past the end."""
    s = cosine(base_lr=1.0, min_lr=0.0, total_steps=10)
    for step, expected in [(0, 1.0), (5, 0.5), (10, 0.0)]:
        assert s(step) == pytest.approx(expected)
    assert s(100) == pytest.approx(0.0)  # held at min past the end
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def test_cosine_restarts_periodic() -> None:
    """With t_mult=1 every cycle has the same length and restarts at base_lr."""
    s = cosine_restarts(base_lr=1.0, min_lr=0.0, period=10, t_mult=1)
    # step 10 is a restart
    for step, expected in [(0, 1.0), (5, 0.5), (10, 1.0), (15, 0.5)]:
        assert s(step) == pytest.approx(expected)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def test_cosine_restarts_growing() -> None:
    """With t_mult=2 the cycles last 10, 20, 40, ... steps."""
    s = cosine_restarts(base_lr=1.0, min_lr=0.0, period=10, t_mult=2)
    assert s(10) == pytest.approx(1.0)  # restart, second cycle has length 20
    assert s(20) == pytest.approx(0.5)  # halfway through the length-20 cycle
    assert s(30) == pytest.approx(1.0)  # restart, third cycle
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def test_inverse_sqrt() -> None:
    """Linear warmup to base_lr, then base_lr * sqrt(warmup_steps / step)."""
    s = inverse_sqrt(base_lr=1.0, warmup_steps=10)
    # step 10: peak at end of warmup; step 40: sqrt(10/40)
    for step, expected in [(0, 0.0), (5, 0.5), (10, 1.0), (40, 0.5)]:
        assert s(step) == pytest.approx(expected)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def test_one_cycle() -> None:
    """One-cycle starts at max/initial_div, peaks at max, ends at max/final_div."""
    s = one_cycle(max_lr=1.0, total_steps=100, pct_start=0.3, initial_div=25.0, final_div=10000.0)
    assert s(0) == pytest.approx(0.04)  # max_lr / initial_div
    assert s(30) == pytest.approx(1.0)  # peak at end of the ramp
    assert s(100) == pytest.approx(0.0001)  # max_lr / final_div
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def test_validation() -> None:
    """Invalid construction arguments and negative steps raise ValueError."""
    bad_calls = [
        lambda: constant(lr=float("inf")),
        lambda: cosine(base_lr=1.0, min_lr=0.0, total_steps=0),
        lambda: one_cycle(
            max_lr=1.0, total_steps=100, pct_start=1.5, initial_div=25.0, final_div=1e4
        ),
        lambda: constant(lr=0.1)(-1),
    ]
    for call in bad_calls:
        with pytest.raises(ValueError):
            call()
|