mushin-py 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mushin_py-0.1.0/LICENSE.txt +28 -0
- mushin_py-0.1.0/PKG-INFO +144 -0
- mushin_py-0.1.0/README.md +106 -0
- mushin_py-0.1.0/pyproject.toml +100 -0
- mushin_py-0.1.0/setup.cfg +4 -0
- mushin_py-0.1.0/src/mushin/__init__.py +25 -0
- mushin_py-0.1.0/src/mushin/_compatibility.py +30 -0
- mushin_py-0.1.0/src/mushin/_utils.py +145 -0
- mushin_py-0.1.0/src/mushin/_validate.py +108 -0
- mushin_py-0.1.0/src/mushin/benchmark/__init__.py +8 -0
- mushin_py-0.1.0/src/mushin/benchmark/_aggregate.py +54 -0
- mushin_py-0.1.0/src/mushin/benchmark/_inference.py +42 -0
- mushin_py-0.1.0/src/mushin/benchmark/_metrics.py +48 -0
- mushin_py-0.1.0/src/mushin/benchmark/_predict.py +21 -0
- mushin_py-0.1.0/src/mushin/benchmark/_result.py +72 -0
- mushin_py-0.1.0/src/mushin/benchmark/_stats.py +122 -0
- mushin_py-0.1.0/src/mushin/benchmark/compare.py +73 -0
- mushin_py-0.1.0/src/mushin/lightning/__init__.py +8 -0
- mushin_py-0.1.0/src/mushin/lightning/_pl_main.py +44 -0
- mushin_py-0.1.0/src/mushin/lightning/callbacks.py +80 -0
- mushin_py-0.1.0/src/mushin/lightning/launchers.py +384 -0
- mushin_py-0.1.0/src/mushin/testing/__init__.py +3 -0
- mushin_py-0.1.0/src/mushin/testing/lightning.py +95 -0
- mushin_py-0.1.0/src/mushin/workflows.py +1203 -0
- mushin_py-0.1.0/src/mushin_py.egg-info/PKG-INFO +144 -0
- mushin_py-0.1.0/src/mushin_py.egg-info/SOURCES.txt +31 -0
- mushin_py-0.1.0/src/mushin_py.egg-info/dependency_links.txt +1 -0
- mushin_py-0.1.0/src/mushin_py.egg-info/requires.txt +23 -0
- mushin_py-0.1.0/src/mushin_py.egg-info/top_level.txt +1 -0
- mushin_py-0.1.0/tests/test_examples.py +26 -0
- mushin_py-0.1.0/tests/test_lightning_callbacks.py +41 -0
- mushin_py-0.1.0/tests/test_lightning_hydra_ddp.py +234 -0
- mushin_py-0.1.0/tests/test_workflows.py +677 -0
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2023 Massachusetts Institute of Technology
|
|
4
|
+
Copyright (c) 2026 mushin contributors
|
|
5
|
+
|
|
6
|
+
This package ("mushin") is derived from the `rai_toolbox.mushin` subpackage of
|
|
7
|
+
MIT Lincoln Laboratory's responsible-ai-toolbox
|
|
8
|
+
(https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox), which is
|
|
9
|
+
distributed under the MIT License. The original copyright notice is retained
|
|
10
|
+
above.
|
|
11
|
+
|
|
12
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
13
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
14
|
+
in the Software without restriction, including without limitation the rights
|
|
15
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
16
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
17
|
+
furnished to do so, subject to the following conditions:
|
|
18
|
+
|
|
19
|
+
The above copyright notice and this permission notice shall be included in all
|
|
20
|
+
copies or substantial portions of the Software.
|
|
21
|
+
|
|
22
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
23
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
24
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
25
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
26
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
27
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
28
|
+
SOFTWARE.
|
mushin_py-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: mushin-py
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Boilerplate-free, reproducible ML experiment workflows built on PyTorch Lightning and hydra-zen. Carved out of MIT-LL's responsible-ai-toolbox.
|
|
5
|
+
Author: Massachusetts Institute of Technology (original rai_toolbox.mushin authors)
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Upstream (origin), https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox
|
|
8
|
+
Keywords: machine learning,reproducibility,pytorch-lightning,hydra,hydra-zen
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Requires-Python: >=3.9
|
|
18
|
+
Description-Content-Type: text/markdown
|
|
19
|
+
License-File: LICENSE.txt
|
|
20
|
+
Requires-Dist: pytorch-lightning>=1.5.0
|
|
21
|
+
Requires-Dist: hydra-zen>=0.9.0
|
|
22
|
+
Requires-Dist: hydra-core>=1.2.0
|
|
23
|
+
Requires-Dist: torch<2.3,>=1.13.0; sys_platform == "darwin" and platform_machine == "x86_64"
|
|
24
|
+
Requires-Dist: torch>=1.13.0; sys_platform != "darwin" or platform_machine != "x86_64"
|
|
25
|
+
Requires-Dist: numpy<2,>=1.24; sys_platform == "darwin" and platform_machine == "x86_64"
|
|
26
|
+
Requires-Dist: numpy>=2; sys_platform != "darwin" or platform_machine != "x86_64"
|
|
27
|
+
Requires-Dist: omegaconf>=2.1.1
|
|
28
|
+
Requires-Dist: xarray>=0.19.0
|
|
29
|
+
Requires-Dist: typing-extensions>=4.1.0
|
|
30
|
+
Requires-Dist: torchmetrics>=1.0
|
|
31
|
+
Requires-Dist: scipy>=1.10
|
|
32
|
+
Requires-Dist: pandas>=1.5
|
|
33
|
+
Provides-Extra: viz
|
|
34
|
+
Requires-Dist: matplotlib>=3.3; extra == "viz"
|
|
35
|
+
Provides-Extra: netcdf
|
|
36
|
+
Requires-Dist: netCDF4>=1.5.8; extra == "netcdf"
|
|
37
|
+
Dynamic: license-file
|
|
38
|
+
|
|
39
|
+
<p align="center">
|
|
40
|
+
<picture>
|
|
41
|
+
<source media="(prefers-color-scheme: dark)" srcset="logos/mushin-dark.png">
|
|
42
|
+
<img src="logos/mushin-light.png" alt="mushin logo" width="200">
|
|
43
|
+
</picture>
|
|
44
|
+
</p>
|
|
45
|
+
|
|
46
|
+
<h1 align="center">mushin</h1>
|
|
47
|
+
|
|
48
|
+
[![CI](https://github.com/martinez-hub/mushin/actions/workflows/ci.yml/badge.svg)](https://github.com/martinez-hub/mushin/actions/workflows/ci.yml)
|
|
49
|
+
[![PyPI version](https://img.shields.io/pypi/v/mushin-py)](https://pypi.org/project/mushin-py/)
|
|
50
|
+
[![Python versions](https://img.shields.io/pypi/pyversions/mushin-py)](https://pypi.org/project/mushin-py/)
|
|
51
|
+
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE.txt)
|
|
52
|
+
|
|
53
|
+
Boilerplate-free, reproducible machine-learning experiment workflows built on
|
|
54
|
+
[PyTorch Lightning](https://lightning.ai/) and
|
|
55
|
+
[hydra-zen](https://github.com/mit-ll-responsible-ai/hydra-zen).
|
|
56
|
+
|
|
57
|
+
`mushin` is a standalone carve-out of the `rai_toolbox.mushin` subpackage from
|
|
58
|
+
MIT Lincoln Laboratory's
|
|
59
|
+
[responsible-ai-toolbox](https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox).
|
|
60
|
+
The upstream toolbox is no longer maintained (last release May 2023), but the
|
|
61
|
+
`mushin` workflow layer still works against current versions of its
|
|
62
|
+
dependencies. This package extracts just that layer so it can be maintained and
|
|
63
|
+
used on its own.
|
|
64
|
+
|
|
65
|
+
## Quickstart: run a sweep, get a dataset
|
|
66
|
+
|
|
67
|
+
Define your experiment as a function, sweep over parameters, and get the results
|
|
68
|
+
back as a labeled `xarray.Dataset` — not rows in a dashboard you have to export.
|
|
69
|
+
|
|
70
|
+
```python
|
|
71
|
+
import torch as tr
|
|
72
|
+
from mushin import multirun
|
|
73
|
+
from mushin.workflows import MultiRunMetricsWorkflow
|
|
74
|
+
|
|
75
|
+
class LRSweep(MultiRunMetricsWorkflow):
|
|
76
|
+
@staticmethod
|
|
77
|
+
def task(lr: float, seed: int) -> dict:
|
|
78
|
+
tr.manual_seed(seed)
|
|
79
|
+
acc = ...  # train a model with this lr/seed and compute its accuracy
|
|
80
|
+
return dict(accuracy=acc) # whatever you return becomes a data variable
|
|
81
|
+
|
|
82
|
+
wf = LRSweep()
|
|
83
|
+
wf.run(lr=multirun([0.01, 0.1, 1.0]), seed=multirun([0, 1, 2])) # 9 runs
|
|
84
|
+
|
|
85
|
+
ds = wf.to_xarray()
|
|
86
|
+
# <xarray.Dataset> Dimensions: (lr: 3, seed: 3)
|
|
87
|
+
# Data variables: accuracy (lr, seed)
|
|
88
|
+
|
|
89
|
+
ds["accuracy"].mean("seed") # average over seeds, per learning rate
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
The full runnable version is in [`examples/sweep_to_dataset.py`](examples/sweep_to_dataset.py):
|
|
93
|
+
|
|
94
|
+
```bash
|
|
95
|
+
uv run python examples/sweep_to_dataset.py
|
|
96
|
+
```
|
|
97
|
+
|
|
98
|
+
## What it provides
|
|
99
|
+
|
|
100
|
+
- `BaseWorkflow`, `MultiRunMetricsWorkflow`, `RobustnessCurve` — declarative,
|
|
101
|
+
reproducible experiment workflows that record configs, checkpoints, and
|
|
102
|
+
metrics, and load results back as labeled `xarray` datasets.
|
|
103
|
+
- `MetricsCallback` — a Lightning callback for capturing metrics.
|
|
104
|
+
- `HydraDDP` — a Hydra/Lightning strategy for multi-GPU (DDP) launches.
|
|
105
|
+
- `multirun`, `hydra_list`, `load_experiment`, `load_from_checkpoint` — helpers.
|
|
106
|
+
|
|
107
|
+
## Install
|
|
108
|
+
|
|
109
|
+
From PyPI (the distribution is named `mushin-py`; you still `import mushin`):
|
|
110
|
+
|
|
111
|
+
```bash
|
|
112
|
+
pip install mushin-py
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
For a development environment (runtime deps + dev tooling), this project uses
|
|
116
|
+
[uv](https://docs.astral.sh/uv/):
|
|
117
|
+
|
|
118
|
+
```bash
|
|
119
|
+
uv sync
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
Optional runtime extras: `viz` (matplotlib, for `RobustnessCurve` plotting) and
|
|
123
|
+
`netcdf` (netCDF4) — e.g. `pip install "mushin-py[viz]"`.
|
|
124
|
+
|
|
125
|
+
## Develop
|
|
126
|
+
|
|
127
|
+
```bash
|
|
128
|
+
uv run pytest tests/ --hypothesis-profile fast # tests (DDP test needs >=2 GPUs)
|
|
129
|
+
uv run ruff check . # lint
|
|
130
|
+
uv run ruff format . # format
|
|
131
|
+
uv run codespell src tests # spell check
|
|
132
|
+
```
|
|
133
|
+
|
|
134
|
+
Or use the `make` shortcuts (`make help` to list them): `make check` runs
|
|
135
|
+
lint + format-check + spell + tests (what CI runs); `make test-py PYTHON=3.12`
|
|
136
|
+
runs the suite on a specific Python version.
|
|
137
|
+
|
|
138
|
+
Supported Python versions: 3.9 – 3.14.
|
|
139
|
+
|
|
140
|
+
## Relationship to upstream
|
|
141
|
+
|
|
142
|
+
This is a fork/extraction, not a replacement endorsed by MIT-LL. The configuration
|
|
143
|
+
engine it depends on, `hydra-zen`, is actively maintained by the same group. See
|
|
144
|
+
`LICENSE.txt` for attribution; the original MIT copyright is retained.
|
|
@@ -0,0 +1,106 @@
|
|
|
1
|
+
<p align="center">
|
|
2
|
+
<picture>
|
|
3
|
+
<source media="(prefers-color-scheme: dark)" srcset="logos/mushin-dark.png">
|
|
4
|
+
<img src="logos/mushin-light.png" alt="mushin logo" width="200">
|
|
5
|
+
</picture>
|
|
6
|
+
</p>
|
|
7
|
+
|
|
8
|
+
<h1 align="center">mushin</h1>
|
|
9
|
+
|
|
10
|
+
[![CI](https://github.com/martinez-hub/mushin/actions/workflows/ci.yml/badge.svg)](https://github.com/martinez-hub/mushin/actions/workflows/ci.yml)
|
|
11
|
+
[![PyPI version](https://img.shields.io/pypi/v/mushin-py)](https://pypi.org/project/mushin-py/)
|
|
12
|
+
[![Python versions](https://img.shields.io/pypi/pyversions/mushin-py)](https://pypi.org/project/mushin-py/)
|
|
13
|
+
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE.txt)
|
|
14
|
+
|
|
15
|
+
Boilerplate-free, reproducible machine-learning experiment workflows built on
|
|
16
|
+
[PyTorch Lightning](https://lightning.ai/) and
|
|
17
|
+
[hydra-zen](https://github.com/mit-ll-responsible-ai/hydra-zen).
|
|
18
|
+
|
|
19
|
+
`mushin` is a standalone carve-out of the `rai_toolbox.mushin` subpackage from
|
|
20
|
+
MIT Lincoln Laboratory's
|
|
21
|
+
[responsible-ai-toolbox](https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox).
|
|
22
|
+
The upstream toolbox is no longer maintained (last release May 2023), but the
|
|
23
|
+
`mushin` workflow layer still works against current versions of its
|
|
24
|
+
dependencies. This package extracts just that layer so it can be maintained and
|
|
25
|
+
used on its own.
|
|
26
|
+
|
|
27
|
+
## Quickstart: run a sweep, get a dataset
|
|
28
|
+
|
|
29
|
+
Define your experiment as a function, sweep over parameters, and get the results
|
|
30
|
+
back as a labeled `xarray.Dataset` — not rows in a dashboard you have to export.
|
|
31
|
+
|
|
32
|
+
```python
|
|
33
|
+
import torch as tr
|
|
34
|
+
from mushin import multirun
|
|
35
|
+
from mushin.workflows import MultiRunMetricsWorkflow
|
|
36
|
+
|
|
37
|
+
class LRSweep(MultiRunMetricsWorkflow):
|
|
38
|
+
@staticmethod
|
|
39
|
+
def task(lr: float, seed: int) -> dict:
|
|
40
|
+
tr.manual_seed(seed)
|
|
41
|
+
acc = ...  # train a model with this lr/seed and compute its accuracy
|
|
42
|
+
return dict(accuracy=acc) # whatever you return becomes a data variable
|
|
43
|
+
|
|
44
|
+
wf = LRSweep()
|
|
45
|
+
wf.run(lr=multirun([0.01, 0.1, 1.0]), seed=multirun([0, 1, 2])) # 9 runs
|
|
46
|
+
|
|
47
|
+
ds = wf.to_xarray()
|
|
48
|
+
# <xarray.Dataset> Dimensions: (lr: 3, seed: 3)
|
|
49
|
+
# Data variables: accuracy (lr, seed)
|
|
50
|
+
|
|
51
|
+
ds["accuracy"].mean("seed") # average over seeds, per learning rate
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
The full runnable version is in [`examples/sweep_to_dataset.py`](examples/sweep_to_dataset.py):
|
|
55
|
+
|
|
56
|
+
```bash
|
|
57
|
+
uv run python examples/sweep_to_dataset.py
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
## What it provides
|
|
61
|
+
|
|
62
|
+
- `BaseWorkflow`, `MultiRunMetricsWorkflow`, `RobustnessCurve` — declarative,
|
|
63
|
+
reproducible experiment workflows that record configs, checkpoints, and
|
|
64
|
+
metrics, and load results back as labeled `xarray` datasets.
|
|
65
|
+
- `MetricsCallback` — a Lightning callback for capturing metrics.
|
|
66
|
+
- `HydraDDP` — a Hydra/Lightning strategy for multi-GPU (DDP) launches.
|
|
67
|
+
- `multirun`, `hydra_list`, `load_experiment`, `load_from_checkpoint` — helpers.
|
|
68
|
+
|
|
69
|
+
## Install
|
|
70
|
+
|
|
71
|
+
From PyPI (the distribution is named `mushin-py`; you still `import mushin`):
|
|
72
|
+
|
|
73
|
+
```bash
|
|
74
|
+
pip install mushin-py
|
|
75
|
+
```
|
|
76
|
+
|
|
77
|
+
For a development environment (runtime deps + dev tooling), this project uses
|
|
78
|
+
[uv](https://docs.astral.sh/uv/):
|
|
79
|
+
|
|
80
|
+
```bash
|
|
81
|
+
uv sync
|
|
82
|
+
```
|
|
83
|
+
|
|
84
|
+
Optional runtime extras: `viz` (matplotlib, for `RobustnessCurve` plotting) and
|
|
85
|
+
`netcdf` (netCDF4) — e.g. `pip install "mushin-py[viz]"`.
|
|
86
|
+
|
|
87
|
+
## Develop
|
|
88
|
+
|
|
89
|
+
```bash
|
|
90
|
+
uv run pytest tests/ --hypothesis-profile fast # tests (DDP test needs >=2 GPUs)
|
|
91
|
+
uv run ruff check . # lint
|
|
92
|
+
uv run ruff format . # format
|
|
93
|
+
uv run codespell src tests # spell check
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
Or use the `make` shortcuts (`make help` to list them): `make check` runs
|
|
97
|
+
lint + format-check + spell + tests (what CI runs); `make test-py PYTHON=3.12`
|
|
98
|
+
runs the suite on a specific Python version.
|
|
99
|
+
|
|
100
|
+
Supported Python versions: 3.9 – 3.14.
|
|
101
|
+
|
|
102
|
+
## Relationship to upstream
|
|
103
|
+
|
|
104
|
+
This is a fork/extraction, not a replacement endorsed by MIT-LL. The configuration
|
|
105
|
+
engine it depends on, `hydra-zen`, is actively maintained by the same group. See
|
|
106
|
+
`LICENSE.txt` for attribution; the original MIT copyright is retained.
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=82.0.1", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
# Distribution name on PyPI is "mushin-py" (the "mushin" name was taken);
|
|
7
|
+
# the import package is still `import mushin` (see src/mushin/).
|
|
8
|
+
name = "mushin-py"
|
|
9
|
+
version = "0.1.0"
|
|
10
|
+
description = "Boilerplate-free, reproducible ML experiment workflows built on PyTorch Lightning and hydra-zen. Carved out of MIT-LL's responsible-ai-toolbox."
|
|
11
|
+
readme = "README.md"
|
|
12
|
+
requires-python = ">=3.9"
|
|
13
|
+
license = { text = "MIT" }
|
|
14
|
+
authors = [
|
|
15
|
+
{ name = "Massachusetts Institute of Technology (original rai_toolbox.mushin authors)" },
|
|
16
|
+
]
|
|
17
|
+
keywords = ["machine learning", "reproducibility", "pytorch-lightning", "hydra", "hydra-zen"]
|
|
18
|
+
classifiers = [
|
|
19
|
+
"Programming Language :: Python :: 3.9",
|
|
20
|
+
"Programming Language :: Python :: 3.10",
|
|
21
|
+
"Programming Language :: Python :: 3.11",
|
|
22
|
+
"Programming Language :: Python :: 3.12",
|
|
23
|
+
"Programming Language :: Python :: 3.13",
|
|
24
|
+
"Programming Language :: Python :: 3.14",
|
|
25
|
+
"License :: OSI Approved :: MIT License",
|
|
26
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
27
|
+
]
|
|
28
|
+
dependencies = [
|
|
29
|
+
"pytorch-lightning >= 1.5.0",
|
|
30
|
+
"hydra-zen >= 0.9.0",
|
|
31
|
+
"hydra-core >= 1.2.0",
|
|
32
|
+
# torch dropped Intel-macOS (x86_64) wheels after 2.2.x, so cap there on that
|
|
33
|
+
# platform only; everywhere else take any modern torch.
|
|
34
|
+
"torch >= 1.13.0,<2.3 ; sys_platform == 'darwin' and platform_machine == 'x86_64'",
|
|
35
|
+
"torch >= 1.13.0 ; sys_platform != 'darwin' or platform_machine != 'x86_64'",
|
|
36
|
+
# torch 2.2.x (the Intel-macOS cap above) is built against NumPy 1.x; that
|
|
37
|
+
# platform is inherently Python <=3.11. Everywhere else take NumPy 2.x.
|
|
38
|
+
"numpy >= 1.24,<2 ; sys_platform == 'darwin' and platform_machine == 'x86_64'",
|
|
39
|
+
"numpy >= 2 ; sys_platform != 'darwin' or platform_machine != 'x86_64'",
|
|
40
|
+
"omegaconf >= 2.1.1",
|
|
41
|
+
"xarray >= 0.19.0",
|
|
42
|
+
"typing-extensions >= 4.1.0",
|
|
43
|
+
"torchmetrics >= 1.0",
|
|
44
|
+
"scipy >= 1.10",
|
|
45
|
+
"pandas >= 1.5",
|
|
46
|
+
]
|
|
47
|
+
|
|
48
|
+
[project.optional-dependencies]
|
|
49
|
+
viz = ["matplotlib >= 3.3"]
|
|
50
|
+
netcdf = ["netCDF4 >= 1.5.8"]
|
|
51
|
+
|
|
52
|
+
[project.urls]
|
|
53
|
+
"Upstream (origin)" = "https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox"
|
|
54
|
+
|
|
55
|
+
# Dev tooling, installed by `uv sync` (PEP 735 dependency group).
|
|
56
|
+
[dependency-groups]
|
|
57
|
+
dev = [
|
|
58
|
+
"ruff >= 0.6",
|
|
59
|
+
"codespell >= 2.3",
|
|
60
|
+
"pre-commit >= 3.5",
|
|
61
|
+
"pytest >= 7.0.0",
|
|
62
|
+
"hypothesis >= 6.28.0",
|
|
63
|
+
"pytest-xdist",
|
|
64
|
+
"matplotlib >= 3.3",
|
|
65
|
+
"netCDF4 >= 1.5.8",
|
|
66
|
+
]
|
|
67
|
+
|
|
68
|
+
[tool.setuptools.packages.find]
|
|
69
|
+
where = ["src"]
|
|
70
|
+
|
|
71
|
+
[tool.pytest.ini_options]
|
|
72
|
+
testpaths = ["tests"]
|
|
73
|
+
pythonpath = ["examples"]
|
|
74
|
+
filterwarnings = [
|
|
75
|
+
# mushin deliberately relies on Hydra's chdir-into-job-dir behavior to read
|
|
76
|
+
# results back; this deprecation is about that behavior changing upstream.
|
|
77
|
+
"ignore:Future Hydra versions will no longer change working directory:UserWarning",
|
|
78
|
+
# matplotlib internals call deprecated pyparsing APIs; not fixable from here.
|
|
79
|
+
"ignore:.* deprecated - use .*",
|
|
80
|
+
]
|
|
81
|
+
|
|
82
|
+
[tool.ruff]
|
|
83
|
+
line-length = 88
|
|
84
|
+
target-version = "py39"
|
|
85
|
+
src = ["src", "tests"]
|
|
86
|
+
|
|
87
|
+
[tool.ruff.lint]
|
|
88
|
+
# E/W: pycodestyle, F: pyflakes, I: isort, UP: pyupgrade, B: bugbear
|
|
89
|
+
select = ["E", "W", "F", "I", "UP", "B"]
|
|
90
|
+
ignore = [
|
|
91
|
+
"E501", # line length is handled by the formatter
|
|
92
|
+
]
|
|
93
|
+
|
|
94
|
+
[tool.ruff.lint.per-file-ignores]
|
|
95
|
+
"tests/*" = ["B"] # bugbear is noisy for test fixtures/assertions
|
|
96
|
+
|
|
97
|
+
[tool.codespell]
|
|
98
|
+
skip = "*.lock,*.ckpt,*.pt,.git,.venv,*.egg-info"
|
|
99
|
+
# "mushin" (無心) and a few ML terms trip the default dictionary.
|
|
100
|
+
ignore-words-list = "mushin,nin,tha"
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
|
|
2
|
+
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
|
|
3
|
+
# SPDX-License-Identifier: MIT
|
|
4
|
+
|
|
5
|
+
from ._utils import load_experiment, load_from_checkpoint
|
|
6
|
+
from .lightning import HydraDDP, MetricsCallback
|
|
7
|
+
from .workflows import (
|
|
8
|
+
BaseWorkflow,
|
|
9
|
+
MultiRunMetricsWorkflow,
|
|
10
|
+
RobustnessCurve,
|
|
11
|
+
hydra_list,
|
|
12
|
+
multirun,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
# Public names re-exported at the package level (``from mushin import *``).
__all__ = [
    # experiment / checkpoint loading helpers
    "load_experiment",
    "load_from_checkpoint",
    # Lightning integration and workflow classes
    "MetricsCallback",
    "MultiRunMetricsWorkflow",
    "HydraDDP",
    "RobustnessCurve",
    "BaseWorkflow",
    # sweep-configuration helpers
    "multirun",
    "hydra_list",
]
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
|
|
2
|
+
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
|
|
3
|
+
# SPDX-License-Identifier: MIT
|
|
4
|
+
|
|
5
|
+
from typing import Final, NamedTuple
|
|
6
|
+
|
|
7
|
+
import pytorch_lightning
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class Version(NamedTuple):
|
|
11
|
+
major: int
|
|
12
|
+
minor: int
|
|
13
|
+
patch: int
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _get_version(ver_str: str) -> Version:
|
|
17
|
+
# Not for general use. Tested only for Hydra and OmegaConf
|
|
18
|
+
# version string styles
|
|
19
|
+
|
|
20
|
+
splits = ver_str.split(".")[:3]
|
|
21
|
+
if not len(splits) == 3: # pragma: no cover
|
|
22
|
+
raise ValueError(f"Version string {ver_str} couldn't be parsed")
|
|
23
|
+
|
|
24
|
+
major, minor = (int(v) for v in splits[:2])
|
|
25
|
+
patch_str, *_ = splits[-1].split("rc")
|
|
26
|
+
|
|
27
|
+
return Version(major=major, minor=minor, patch=int(patch_str))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Installed PyTorch Lightning version as a (major, minor, patch) tuple,
# parsed once when this module is imported.
PL_VERSION: Final = _get_version(pytorch_lightning.__version__)
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
# Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
|
|
2
|
+
# Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
|
|
3
|
+
# SPDX-License-Identifier: MIT
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Optional, Union
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from hydra_zen import load_from_yaml
|
|
12
|
+
from omegaconf import DictConfig, ListConfig
|
|
13
|
+
from torch import nn
|
|
14
|
+
|
|
15
|
+
log = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def load_from_checkpoint(
    model: nn.Module,
    *,
    ckpt: Optional[Union[str, Path]] = None,
    weights_key: Optional[str] = None,
    weights_key_strip: Optional[str] = None,
    model_attr: Optional[str] = None,
) -> nn.Module:
    """Load weights from a checkpoint file into ``model`` (in place).

    Parameters
    ----------
    model : nn.Module
        The module to update. It is modified in place and also returned.

    ckpt : Optional[Union[str, Path]] (default: None)
        Path to the checkpoint file. If the path does not exist as given, it is
        looked up under ``~/.torch/models``. If ``None``, ``model`` is returned
        unchanged.

    weights_key : Optional[str] (default: None)
        If given, the entry of the loaded checkpoint that holds the weights
        (e.g. ``"state_dict"``). If ``None``, the loaded object itself is used
        as the state dict.

    weights_key_strip : Optional[str] (default: None)
        If given, only weight keys starting with this prefix are kept, with the
        prefix removed (a trailing ``"."`` is appended if missing). For example,
        ``"model"`` turns ``"model.layer.weight"`` into ``"layer.weight"``.

    model_attr : Optional[str] (default: None)
        If given, the weights are loaded into ``getattr(model, model_attr)``
        rather than into ``model`` itself.

    Returns
    -------
    model : nn.Module
        The same ``model`` object that was passed in.

    Raises
    ------
    AssertionError
        If ``weights_key`` is not present in the checkpoint, or ``model`` has no
        attribute ``model_attr``.
    """
    if ckpt is None:
        return model

    ckpt = Path(str(ckpt))
    if not ckpt.exists():
        ckpt = Path.home() / ".torch" / "models" / ckpt
    log.info("Loading model checkpoint from %s", ckpt)

    # weights_only=False: these are trusted, self-produced checkpoints that may
    # hold more than tensors. torch 2.6 flipped this default to True.
    ckpt_data: dict[str, Any] = torch.load(ckpt, map_location="cpu", weights_only=False)

    if weights_key is not None:
        assert weights_key in ckpt_data, (
            f"weights_key {weights_key!r} not found in checkpoint {ckpt}; "
            f"available keys: {list(ckpt_data)}"
        )
        ckpt_data = ckpt_data[weights_key]

    if weights_key_strip:
        prefix = (
            weights_key_strip
            if weights_key_strip.endswith(".")
            else weights_key_strip + "."
        )
        # Keys without the prefix belong to other sub-modules and are dropped.
        ckpt_data = {
            k.removeprefix(prefix): v
            for k, v in ckpt_data.items()
            if k.startswith(prefix)
        }

    if model_attr is None:
        # The weights can be loaded in directly
        target = model
    else:
        assert hasattr(model, model_attr), (
            f"model of type {type(model).__name__} has no attribute {model_attr!r}"
        )
        target = getattr(model, model_attr)

    target.load_state_dict(ckpt_data)
    return model
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@dataclass
class Experiment:
    """Artifacts collected from one run directory by ``load_experiment``.

    The constructor takes, positionally: ``working_dir``, ``cfg``, ``ckpts``,
    ``metrics``.
    """

    # Directory associated with the run, stored as a string path.
    working_dir: str
    # The run's loaded configuration, or ``None`` if it could not be found
    # unambiguously.
    cfg: Union[dict, ListConfig, DictConfig, None]
    # String paths of the checkpoint files found for the run.
    ckpts: list[str]
    # Maps a metrics-file name to the object loaded from that file.
    metrics: dict
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def load_experiment(
    exp_path: Union[str, Path], search_path: Optional[Union[str, Path]] = None
) -> Union[Experiment, list[Experiment]]:
    """Load the configuration, metrics, and checkpoints found under a directory.

    Every path matching ``**/<search_path>`` beneath ``exp_path`` marks one run.
    For each match, its parent directory (the run directory) is scanned for a
    ``config.yaml``, top-level ``*.pt`` metrics files, and ``*.ckpt`` files.

    Parameters
    ----------
    exp_path : Union[str, Path]
        The directory to search for data. Directory must include the
        ".hydra/config.yaml" file.

    search_path : Optional[Union[str, Path]] (default: ".hydra")
        The name searched for recursively beneath ``exp_path``; each match marks
        one run.

    Returns
    -------
    exps : Union[Experiment, list[Experiment]]
        A single ``Experiment`` when exactly one run is found, otherwise a list
        (empty when nothing matched). Metrics entries and checkpoint paths are
        ordered by file path, so repeated calls give identical results.
    """
    assert Path(exp_path).exists(), f"{exp_path} not found"

    if search_path is None:
        search_path = ".hydra"
    run_markers = sorted(Path(exp_path).absolute().glob(f"**/{search_path}"))

    exps = []
    for marker in run_markers:
        run_dir = marker.parent

        # The config is only used when exactly one is found; nested runs inside
        # `run_dir` would also match this recursive glob.
        yaml_files = list(run_dir.glob("**/config.yaml"))
        cfg = load_from_yaml(yaml_files[0]) if len(yaml_files) == 1 else None

        # "<name>.pt" -> metrics["<name>"]. Files are sorted so the dict's key
        # order does not depend on filesystem enumeration order.
        # NOTE: weights_only=False unpickles arbitrary objects; only point this
        # at directories produced by trusted runs.
        metrics = {
            f.name[:-3]: torch.load(f, weights_only=False)
            for f in sorted(run_dir.glob("*.pt"))
        }

        # Sorted for a deterministic checkpoint order.
        ckpts = sorted(str(ckpt.resolve()) for ckpt in run_dir.glob("**/*.ckpt"))

        # NOTE(review): working_dir is the *parent* of the run directory (the
        # grandparent of the matched marker). Kept as-is for backward
        # compatibility -- confirm this is intended.
        exps.append(Experiment(str(run_dir.parent), cfg, ckpts, metrics))

    if len(exps) == 1:
        return exps[0]

    return exps
|