mushin-py 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. mushin_py-0.1.0/LICENSE.txt +28 -0
  2. mushin_py-0.1.0/PKG-INFO +144 -0
  3. mushin_py-0.1.0/README.md +106 -0
  4. mushin_py-0.1.0/pyproject.toml +100 -0
  5. mushin_py-0.1.0/setup.cfg +4 -0
  6. mushin_py-0.1.0/src/mushin/__init__.py +25 -0
  7. mushin_py-0.1.0/src/mushin/_compatibility.py +30 -0
  8. mushin_py-0.1.0/src/mushin/_utils.py +145 -0
  9. mushin_py-0.1.0/src/mushin/_validate.py +108 -0
  10. mushin_py-0.1.0/src/mushin/benchmark/__init__.py +8 -0
  11. mushin_py-0.1.0/src/mushin/benchmark/_aggregate.py +54 -0
  12. mushin_py-0.1.0/src/mushin/benchmark/_inference.py +42 -0
  13. mushin_py-0.1.0/src/mushin/benchmark/_metrics.py +48 -0
  14. mushin_py-0.1.0/src/mushin/benchmark/_predict.py +21 -0
  15. mushin_py-0.1.0/src/mushin/benchmark/_result.py +72 -0
  16. mushin_py-0.1.0/src/mushin/benchmark/_stats.py +122 -0
  17. mushin_py-0.1.0/src/mushin/benchmark/compare.py +73 -0
  18. mushin_py-0.1.0/src/mushin/lightning/__init__.py +8 -0
  19. mushin_py-0.1.0/src/mushin/lightning/_pl_main.py +44 -0
  20. mushin_py-0.1.0/src/mushin/lightning/callbacks.py +80 -0
  21. mushin_py-0.1.0/src/mushin/lightning/launchers.py +384 -0
  22. mushin_py-0.1.0/src/mushin/testing/__init__.py +3 -0
  23. mushin_py-0.1.0/src/mushin/testing/lightning.py +95 -0
  24. mushin_py-0.1.0/src/mushin/workflows.py +1203 -0
  25. mushin_py-0.1.0/src/mushin_py.egg-info/PKG-INFO +144 -0
  26. mushin_py-0.1.0/src/mushin_py.egg-info/SOURCES.txt +31 -0
  27. mushin_py-0.1.0/src/mushin_py.egg-info/dependency_links.txt +1 -0
  28. mushin_py-0.1.0/src/mushin_py.egg-info/requires.txt +23 -0
  29. mushin_py-0.1.0/src/mushin_py.egg-info/top_level.txt +1 -0
  30. mushin_py-0.1.0/tests/test_examples.py +26 -0
  31. mushin_py-0.1.0/tests/test_lightning_callbacks.py +41 -0
  32. mushin_py-0.1.0/tests/test_lightning_hydra_ddp.py +234 -0
  33. mushin_py-0.1.0/tests/test_workflows.py +677 -0
@@ -0,0 +1,28 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Massachusetts Institute of Technology
4
+ Copyright (c) 2026 mushin contributors
5
+
6
+ This package ("mushin") is derived from the `rai_toolbox.mushin` subpackage of
7
+ MIT Lincoln Laboratory's responsible-ai-toolbox
8
+ (https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox), which is
9
+ distributed under the MIT License. The original copyright notice is retained
10
+ above.
11
+
12
+ Permission is hereby granted, free of charge, to any person obtaining a copy
13
+ of this software and associated documentation files (the "Software"), to deal
14
+ in the Software without restriction, including without limitation the rights
15
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the Software is
17
+ furnished to do so, subject to the following conditions:
18
+
19
+ The above copyright notice and this permission notice shall be included in all
20
+ copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
+ SOFTWARE.
@@ -0,0 +1,144 @@
1
+ Metadata-Version: 2.4
2
+ Name: mushin-py
3
+ Version: 0.1.0
4
+ Summary: Boilerplate-free, reproducible ML experiment workflows built on PyTorch Lightning and hydra-zen. Carved out of MIT-LL's responsible-ai-toolbox.
5
+ Author: Massachusetts Institute of Technology (original rai_toolbox.mushin authors)
6
+ License: MIT
7
+ Project-URL: Upstream (origin), https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox
8
+ Keywords: machine learning,reproducibility,pytorch-lightning,hydra,hydra-zen
9
+ Classifier: Programming Language :: Python :: 3.9
10
+ Classifier: Programming Language :: Python :: 3.10
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: Programming Language :: Python :: 3.12
13
+ Classifier: Programming Language :: Python :: 3.13
14
+ Classifier: Programming Language :: Python :: 3.14
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Requires-Python: >=3.9
18
+ Description-Content-Type: text/markdown
19
+ License-File: LICENSE.txt
20
+ Requires-Dist: pytorch-lightning>=1.5.0
21
+ Requires-Dist: hydra-zen>=0.9.0
22
+ Requires-Dist: hydra-core>=1.2.0
23
+ Requires-Dist: torch<2.3,>=1.13.0; sys_platform == "darwin" and platform_machine == "x86_64"
24
+ Requires-Dist: torch>=1.13.0; sys_platform != "darwin" or platform_machine != "x86_64"
25
+ Requires-Dist: numpy<2,>=1.24; sys_platform == "darwin" and platform_machine == "x86_64"
26
+ Requires-Dist: numpy>=2; sys_platform != "darwin" or platform_machine != "x86_64"
27
+ Requires-Dist: omegaconf>=2.1.1
28
+ Requires-Dist: xarray>=0.19.0
29
+ Requires-Dist: typing-extensions>=4.1.0
30
+ Requires-Dist: torchmetrics>=1.0
31
+ Requires-Dist: scipy>=1.10
32
+ Requires-Dist: pandas>=1.5
33
+ Provides-Extra: viz
34
+ Requires-Dist: matplotlib>=3.3; extra == "viz"
35
+ Provides-Extra: netcdf
36
+ Requires-Dist: netCDF4>=1.5.8; extra == "netcdf"
37
+ Dynamic: license-file
38
+
39
+ <p align="center">
40
+ <picture>
41
+ <source media="(prefers-color-scheme: dark)" srcset="logos/mushin-dark.png">
42
+ <img src="logos/mushin-light.png" alt="mushin logo" width="200">
43
+ </picture>
44
+ </p>
45
+
46
+ <h1 align="center">mushin</h1>
47
+
48
+ [![CI](https://github.com/martinez-hub/mushin/actions/workflows/ci.yml/badge.svg)](https://github.com/martinez-hub/mushin/actions/workflows/ci.yml)
49
+ [![PyPI](https://img.shields.io/pypi/v/mushin-py.svg)](https://pypi.org/project/mushin-py/)
50
+ [![Python versions](https://img.shields.io/pypi/pyversions/mushin-py.svg)](https://pypi.org/project/mushin-py/)
51
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE.txt)
52
+
53
+ Boilerplate-free, reproducible machine-learning experiment workflows built on
54
+ [PyTorch Lightning](https://lightning.ai/) and
55
+ [hydra-zen](https://github.com/mit-ll-responsible-ai/hydra-zen).
56
+
57
+ `mushin` is a standalone carve-out of the `rai_toolbox.mushin` subpackage from
58
+ MIT Lincoln Laboratory's
59
+ [responsible-ai-toolbox](https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox).
60
+ The upstream toolbox is no longer maintained (last release May 2023), but the
61
+ `mushin` workflow layer still works against current versions of its
62
+ dependencies. This package extracts just that layer so it can be maintained and
63
+ used on its own.
64
+
65
+ ## Quickstart: run a sweep, get a dataset
66
+
67
+ Define your experiment as a function, sweep over parameters, and get the results
68
+ back as a labeled `xarray.Dataset` — not rows in a dashboard you have to export.
69
+
70
+ ```python
71
+ import torch as tr
72
+ from mushin import multirun
73
+ from mushin.workflows import MultiRunMetricsWorkflow
74
+
75
+ class LRSweep(MultiRunMetricsWorkflow):
76
+ @staticmethod
77
+ def task(lr: float, seed: int) -> dict:
78
+ tr.manual_seed(seed)
79
+ acc = ...  # placeholder: train a model with this lr/seed and evaluate it
80
+ return dict(accuracy=acc) # whatever you return becomes a data variable
81
+
82
+ wf = LRSweep()
83
+ wf.run(lr=multirun([0.01, 0.1, 1.0]), seed=multirun([0, 1, 2])) # 9 runs
84
+
85
+ ds = wf.to_xarray()
86
+ # <xarray.Dataset> Dimensions: (lr: 3, seed: 3)
87
+ # Data variables: accuracy (lr, seed)
88
+
89
+ ds["accuracy"].mean("seed") # average over seeds, per learning rate
90
+ ```
91
+
92
+ The full runnable version is in [`examples/sweep_to_dataset.py`](examples/sweep_to_dataset.py):
93
+
94
+ ```bash
95
+ uv run python examples/sweep_to_dataset.py
96
+ ```
97
+
98
+ ## What it provides
99
+
100
+ - `BaseWorkflow`, `MultiRunMetricsWorkflow`, `RobustnessCurve` — declarative,
101
+ reproducible experiment workflows that record configs, checkpoints, and
102
+ metrics, and load results back as labeled `xarray` datasets.
103
+ - `MetricsCallback` — a Lightning callback for capturing metrics.
104
+ - `HydraDDP` — a Hydra/Lightning strategy for multi-GPU (DDP) launches.
105
+ - `multirun`, `hydra_list`, `load_experiment`, `load_from_checkpoint` — helpers.
106
+
107
+ ## Install
108
+
109
+ From PyPI (the distribution is named `mushin-py`; you still `import mushin`):
110
+
111
+ ```bash
112
+ pip install mushin-py
113
+ ```
114
+
115
+ For a development environment (runtime deps + dev tooling), this project uses
116
+ [uv](https://docs.astral.sh/uv/):
117
+
118
+ ```bash
119
+ uv sync
120
+ ```
121
+
122
+ Optional runtime extras: `viz` (matplotlib, for `RobustnessCurve` plotting) and
123
+ `netcdf` (netCDF4) — e.g. `pip install "mushin-py[viz]"`.
124
+
125
+ ## Develop
126
+
127
+ ```bash
128
+ uv run pytest tests/ --hypothesis-profile fast # tests (DDP test needs >=2 GPUs)
129
+ uv run ruff check . # lint
130
+ uv run ruff format . # format
131
+ uv run codespell src tests # spell check
132
+ ```
133
+
134
+ Or use the `make` shortcuts (`make help` to list them): `make check` runs
135
+ lint + format-check + spell + tests (what CI runs); `make test-py PYTHON=3.12`
136
+ runs the suite on a specific Python version.
137
+
138
+ Supported Python versions: 3.9 – 3.14.
139
+
140
+ ## Relationship to upstream
141
+
142
+ This is a fork/extraction, not a replacement endorsed by MIT-LL. The configuration
143
+ engine it depends on, `hydra-zen`, is actively maintained by the same group. See
144
+ `LICENSE.txt` for attribution; the original MIT copyright is retained.
@@ -0,0 +1,106 @@
1
+ <p align="center">
2
+ <picture>
3
+ <source media="(prefers-color-scheme: dark)" srcset="logos/mushin-dark.png">
4
+ <img src="logos/mushin-light.png" alt="mushin logo" width="200">
5
+ </picture>
6
+ </p>
7
+
8
+ <h1 align="center">mushin</h1>
9
+
10
+ [![CI](https://github.com/martinez-hub/mushin/actions/workflows/ci.yml/badge.svg)](https://github.com/martinez-hub/mushin/actions/workflows/ci.yml)
11
+ [![PyPI](https://img.shields.io/pypi/v/mushin-py.svg)](https://pypi.org/project/mushin-py/)
12
+ [![Python versions](https://img.shields.io/pypi/pyversions/mushin-py.svg)](https://pypi.org/project/mushin-py/)
13
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE.txt)
14
+
15
+ Boilerplate-free, reproducible machine-learning experiment workflows built on
16
+ [PyTorch Lightning](https://lightning.ai/) and
17
+ [hydra-zen](https://github.com/mit-ll-responsible-ai/hydra-zen).
18
+
19
+ `mushin` is a standalone carve-out of the `rai_toolbox.mushin` subpackage from
20
+ MIT Lincoln Laboratory's
21
+ [responsible-ai-toolbox](https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox).
22
+ The upstream toolbox is no longer maintained (last release May 2023), but the
23
+ `mushin` workflow layer still works against current versions of its
24
+ dependencies. This package extracts just that layer so it can be maintained and
25
+ used on its own.
26
+
27
+ ## Quickstart: run a sweep, get a dataset
28
+
29
+ Define your experiment as a function, sweep over parameters, and get the results
30
+ back as a labeled `xarray.Dataset` — not rows in a dashboard you have to export.
31
+
32
+ ```python
33
+ import torch as tr
34
+ from mushin import multirun
35
+ from mushin.workflows import MultiRunMetricsWorkflow
36
+
37
+ class LRSweep(MultiRunMetricsWorkflow):
38
+ @staticmethod
39
+ def task(lr: float, seed: int) -> dict:
40
+ tr.manual_seed(seed)
41
+ acc = ...  # placeholder: train a model with this lr/seed and evaluate it
42
+ return dict(accuracy=acc) # whatever you return becomes a data variable
43
+
44
+ wf = LRSweep()
45
+ wf.run(lr=multirun([0.01, 0.1, 1.0]), seed=multirun([0, 1, 2])) # 9 runs
46
+
47
+ ds = wf.to_xarray()
48
+ # <xarray.Dataset> Dimensions: (lr: 3, seed: 3)
49
+ # Data variables: accuracy (lr, seed)
50
+
51
+ ds["accuracy"].mean("seed") # average over seeds, per learning rate
52
+ ```
53
+
54
+ The full runnable version is in [`examples/sweep_to_dataset.py`](examples/sweep_to_dataset.py):
55
+
56
+ ```bash
57
+ uv run python examples/sweep_to_dataset.py
58
+ ```
59
+
60
+ ## What it provides
61
+
62
+ - `BaseWorkflow`, `MultiRunMetricsWorkflow`, `RobustnessCurve` — declarative,
63
+ reproducible experiment workflows that record configs, checkpoints, and
64
+ metrics, and load results back as labeled `xarray` datasets.
65
+ - `MetricsCallback` — a Lightning callback for capturing metrics.
66
+ - `HydraDDP` — a Hydra/Lightning strategy for multi-GPU (DDP) launches.
67
+ - `multirun`, `hydra_list`, `load_experiment`, `load_from_checkpoint` — helpers.
68
+
69
+ ## Install
70
+
71
+ From PyPI (the distribution is named `mushin-py`; you still `import mushin`):
72
+
73
+ ```bash
74
+ pip install mushin-py
75
+ ```
76
+
77
+ For a development environment (runtime deps + dev tooling), this project uses
78
+ [uv](https://docs.astral.sh/uv/):
79
+
80
+ ```bash
81
+ uv sync
82
+ ```
83
+
84
+ Optional runtime extras: `viz` (matplotlib, for `RobustnessCurve` plotting) and
85
+ `netcdf` (netCDF4) — e.g. `pip install "mushin-py[viz]"`.
86
+
87
+ ## Develop
88
+
89
+ ```bash
90
+ uv run pytest tests/ --hypothesis-profile fast # tests (DDP test needs >=2 GPUs)
91
+ uv run ruff check . # lint
92
+ uv run ruff format . # format
93
+ uv run codespell src tests # spell check
94
+ ```
95
+
96
+ Or use the `make` shortcuts (`make help` to list them): `make check` runs
97
+ lint + format-check + spell + tests (what CI runs); `make test-py PYTHON=3.12`
98
+ runs the suite on a specific Python version.
99
+
100
+ Supported Python versions: 3.9 – 3.14.
101
+
102
+ ## Relationship to upstream
103
+
104
+ This is a fork/extraction, not a replacement endorsed by MIT-LL. The configuration
105
+ engine it depends on, `hydra-zen`, is actively maintained by the same group. See
106
+ `LICENSE.txt` for attribution; the original MIT copyright is retained.
@@ -0,0 +1,100 @@
1
+ [build-system]
2
+ requires = ["setuptools>=82.0.1", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ # Distribution name on PyPI is "mushin-py" (the "mushin" name was taken);
7
+ # the import package is still `import mushin` (see src/mushin/).
8
+ name = "mushin-py"
9
+ version = "0.1.0"
10
+ description = "Boilerplate-free, reproducible ML experiment workflows built on PyTorch Lightning and hydra-zen. Carved out of MIT-LL's responsible-ai-toolbox."
11
+ readme = "README.md"
12
+ requires-python = ">=3.9"
13
+ license = { text = "MIT" }
14
+ authors = [
15
+ { name = "Massachusetts Institute of Technology (original rai_toolbox.mushin authors)" },
16
+ ]
17
+ keywords = ["machine learning", "reproducibility", "pytorch-lightning", "hydra", "hydra-zen"]
18
+ classifiers = [
19
+ "Programming Language :: Python :: 3.9",
20
+ "Programming Language :: Python :: 3.10",
21
+ "Programming Language :: Python :: 3.11",
22
+ "Programming Language :: Python :: 3.12",
23
+ "Programming Language :: Python :: 3.13",
24
+ "Programming Language :: Python :: 3.14",
25
+ "License :: OSI Approved :: MIT License",
26
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
27
+ ]
28
+ dependencies = [
29
+ "pytorch-lightning >= 1.5.0",
30
+ "hydra-zen >= 0.9.0",
31
+ "hydra-core >= 1.2.0",
32
+ # torch dropped Intel-macOS (x86_64) wheels after 2.2.x, so cap there on that
33
+ # platform only; everywhere else take any modern torch.
34
+ "torch >= 1.13.0,<2.3 ; sys_platform == 'darwin' and platform_machine == 'x86_64'",
35
+ "torch >= 1.13.0 ; sys_platform != 'darwin' or platform_machine != 'x86_64'",
36
+ # torch 2.2.x (the Intel-macOS cap above) is built against NumPy 1.x; that
37
+ # platform is inherently Python <=3.11. Everywhere else take NumPy 2.x.
38
+ "numpy >= 1.24,<2 ; sys_platform == 'darwin' and platform_machine == 'x86_64'",
39
+ "numpy >= 2 ; sys_platform != 'darwin' or platform_machine != 'x86_64'",
40
+ "omegaconf >= 2.1.1",
41
+ "xarray >= 0.19.0",
42
+ "typing-extensions >= 4.1.0",
43
+ "torchmetrics >= 1.0",
44
+ "scipy >= 1.10",
45
+ "pandas >= 1.5",
46
+ ]
47
+
48
+ [project.optional-dependencies]
49
+ viz = ["matplotlib >= 3.3"]
50
+ netcdf = ["netCDF4 >= 1.5.8"]
51
+
52
+ [project.urls]
53
+ "Upstream (origin)" = "https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox"
54
+
55
+ # Dev tooling, installed by `uv sync` (PEP 735 dependency group).
56
+ [dependency-groups]
57
+ dev = [
58
+ "ruff >= 0.6",
59
+ "codespell >= 2.3",
60
+ "pre-commit >= 3.5",
61
+ "pytest >= 7.0.0",
62
+ "hypothesis >= 6.28.0",
63
+ "pytest-xdist",
64
+ "matplotlib >= 3.3",
65
+ "netCDF4 >= 1.5.8",
66
+ ]
67
+
68
+ [tool.setuptools.packages.find]
69
+ where = ["src"]
70
+
71
+ [tool.pytest.ini_options]
72
+ testpaths = ["tests"]
73
+ pythonpath = ["examples"]
74
+ filterwarnings = [
75
+ # mushin deliberately relies on Hydra's chdir-into-job-dir behavior to read
76
+ # results back; this deprecation is about that behavior changing upstream.
77
+ "ignore:Future Hydra versions will no longer change working directory:UserWarning",
78
+ # matplotlib internals call deprecated pyparsing APIs; not fixable from here.
79
+ "ignore:.* deprecated - use .*",
80
+ ]
81
+
82
+ [tool.ruff]
83
+ line-length = 88
84
+ target-version = "py39"
85
+ src = ["src", "tests"]
86
+
87
+ [tool.ruff.lint]
88
+ # E/W: pycodestyle, F: pyflakes, I: isort, UP: pyupgrade, B: bugbear
89
+ select = ["E", "W", "F", "I", "UP", "B"]
90
+ ignore = [
91
+ "E501", # line length is handled by the formatter
92
+ ]
93
+
94
+ [tool.ruff.lint.per-file-ignores]
95
+ "tests/*" = ["B"] # bugbear is noisy for test fixtures/assertions
96
+
97
+ [tool.codespell]
98
+ skip = "*.lock,*.ckpt,*.pt,.git,.venv,*.egg-info"
99
+ # "mushin" (無心) and a few ML terms trip the default dictionary.
100
+ ignore-words-list = "mushin,nin,tha"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,25 @@
1
+ # Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
2
+ # Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ from ._utils import load_experiment, load_from_checkpoint
6
+ from .lightning import HydraDDP, MetricsCallback
7
+ from .workflows import (
8
+ BaseWorkflow,
9
+ MultiRunMetricsWorkflow,
10
+ RobustnessCurve,
11
+ hydra_list,
12
+ multirun,
13
+ )
14
+
15
# Public names re-exported at the package root; this is what
# `from mushin import *` exposes.
__all__ = [
    "load_experiment",
    "load_from_checkpoint",
    "MetricsCallback",
    "MultiRunMetricsWorkflow",
    "HydraDDP",
    "RobustnessCurve",
    "BaseWorkflow",
    "multirun",
    "hydra_list",
]
@@ -0,0 +1,30 @@
1
+ # Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
2
+ # Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ from typing import Final, NamedTuple
6
+
7
+ import pytorch_lightning
8
+
9
+
10
+ class Version(NamedTuple):
11
+ major: int
12
+ minor: int
13
+ patch: int
14
+
15
+
16
+ def _get_version(ver_str: str) -> Version:
17
+ # Not for general use. Tested only for Hydra and OmegaConf
18
+ # version string styles
19
+
20
+ splits = ver_str.split(".")[:3]
21
+ if not len(splits) == 3: # pragma: no cover
22
+ raise ValueError(f"Version string {ver_str} couldn't be parsed")
23
+
24
+ major, minor = (int(v) for v in splits[:2])
25
+ patch_str, *_ = splits[-1].split("rc")
26
+
27
+ return Version(major=major, minor=minor, patch=int(patch_str))
28
+
29
+
30
# (major, minor, patch) of the installed pytorch_lightning, parsed once at
# import time of this module.
PL_VERSION: Final = _get_version(pytorch_lightning.__version__)
@@ -0,0 +1,145 @@
1
+ # Copyright 2023, MASSACHUSETTS INSTITUTE OF TECHNOLOGY
2
+ # Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014).
3
+ # SPDX-License-Identifier: MIT
4
+
5
+ import logging
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Any, Optional, Union
9
+
10
+ import torch
11
+ from hydra_zen import load_from_yaml
12
+ from omegaconf import DictConfig, ListConfig
13
+ from torch import nn
14
+
15
log = logging.getLogger(__name__)


def load_from_checkpoint(
    model: nn.Module,
    *,
    ckpt: Optional[Union[str, Path]] = None,
    weights_key: Optional[str] = None,
    weights_key_strip: Optional[str] = None,
    model_attr: Optional[str] = None,
) -> nn.Module:
    """Load weights from a checkpoint file into ``model`` (in place).

    Parameters
    ----------
    model : nn.Module
        The PyTorch module whose weights are updated. It is modified in place
        and also returned.

    ckpt : Optional[Union[str, Path]] (default: None)
        Path to the file containing the model weights. If the path does not
        exist as given, ``~/.torch/models/<ckpt>`` is tried instead. If
        ``None``, ``model`` is returned unchanged.

    weights_key : Optional[str] (default: None)
        If provided, the key of the loaded checkpoint mapping that holds the
        weights (e.g. ``"state_dict"`` for Lightning checkpoints). If ``None``,
        the loaded object itself is used as the state dict.

    weights_key_strip : Optional[str] (default: None)
        If provided, only entries whose key starts with this prefix (a trailing
        ``"."`` is appended if missing) are kept, and the prefix is removed from
        each key. For example, ``"model"`` maps ``"model.fc.weight"`` to
        ``"fc.weight"`` and drops keys that lack the prefix.

    model_attr : Optional[str] (default: None)
        If provided, the weights are loaded into ``getattr(model, model_attr)``
        instead of into ``model`` itself.

    Returns
    -------
    model : nn.Module
        The same ``model`` object that was passed in.

    Raises
    ------
    AssertionError
        If ``weights_key`` is not in the checkpoint, or ``model`` has no
        attribute ``model_attr`` (checked via ``assert``).

    Warnings
    --------
    The checkpoint is loaded with ``weights_only=False``, which unpickles
    arbitrary objects. Only load checkpoints from trusted sources.
    """
    if ckpt is None:
        return model

    ckpt_path = Path(str(ckpt))
    if not ckpt_path.exists():
        # Fall back to the legacy torch model-cache directory.
        ckpt_path = Path.home() / ".torch" / "models" / ckpt_path
    log.info("Loading model checkpoint from %s", ckpt_path)

    # weights_only=False: these are trusted, self-produced checkpoints that may
    # hold more than tensors. torch 2.6 flipped this default to True.
    # SECURITY: this executes pickle; never point it at untrusted files.
    ckpt_data: dict[str, Any] = torch.load(
        ckpt_path, map_location="cpu", weights_only=False
    )

    if weights_key is not None:
        assert weights_key in ckpt_data, (
            f"weights_key={weights_key!r} not found in checkpoint {ckpt_path}"
        )
        ckpt_data = ckpt_data[weights_key]

    if weights_key_strip:
        prefix = (
            weights_key_strip
            if weights_key_strip.endswith(".")
            else weights_key_strip + "."
        )
        # Keep only the prefixed entries, with the prefix removed.
        ckpt_data = {
            k[len(prefix) :]: v for k, v in ckpt_data.items() if k.startswith(prefix)
        }

    if model_attr is None:
        # The weights can be loaded in directly
        model.load_state_dict(ckpt_data)
    else:
        assert hasattr(model, model_attr), (
            f"{type(model).__name__} has no attribute {model_attr!r}"
        )
        getattr(model, model_attr).load_state_dict(ckpt_data)

    return model
87
+
88
+
89
@dataclass
class Experiment:
    """Artifacts of one run directory, as populated by ``load_experiment``."""

    # Directory string recorded for the run. load_experiment sets it to the
    # parent of the directory that contains the matched `.hydra` entry.
    working_dir: str
    # The run's config loaded from YAML, or None when load_experiment found
    # zero or multiple config.yaml files.
    cfg: Optional[Union[dict, ListConfig, DictConfig]]
    # Resolved paths (as strings) of *.ckpt files found under the run directory.
    ckpts: list[str]
    # Objects loaded from the run's *.pt files, keyed by file name minus ".pt".
    metrics: dict
95
+
96
+
97
def load_experiment(
    exp_path: Union[str, Path], search_path: Optional[Union[str, Path]] = None
) -> Union[Experiment, list[Experiment]]:
    """Load the config, metrics, and checkpoint paths of each run under ``exp_path``.

    Every entry matching ``**/<search_path>`` (``**/.hydra`` by default) below
    ``exp_path`` marks one run; the run directory is that match's parent.
    For each run:

    - ``cfg``: loaded with ``hydra_zen.load_from_yaml`` when exactly one
      ``config.yaml`` exists anywhere under the run directory, else ``None``.
    - ``metrics``: every ``*.pt`` file directly inside the run directory, loaded
      with ``torch.load`` and keyed by file name without the ``.pt`` extension.
    - ``ckpts``: resolved paths of all ``*.ckpt`` files under the run directory,
      in sorted order.

    Parameters
    ----------
    exp_path : Union[str, Path]
        The directory to search for data. It must exist.

    search_path : Optional[Union[str, Path]] (default: ".hydra")
        Relative glob pattern, matched recursively under ``exp_path``, that
        identifies run directories.

    Returns
    -------
    exps : Union[Experiment, list[Experiment]]
        A single ``Experiment`` when exactly one run is found; otherwise a list
        (empty if none are found), ordered by the sorted match paths.

    Raises
    ------
    AssertionError
        If ``exp_path`` does not exist (checked via ``assert``).

    Warnings
    --------
    Metric files are loaded with ``weights_only=False``, which unpickles
    arbitrary objects. Only use this on trusted experiment directories.
    """
    assert Path(exp_path).exists(), f"{exp_path} not found"

    # first find all run markers (`.hydra` directories by default)
    if search_path is None:
        search_path = ".hydra"
    markers = sorted(Path(exp_path).absolute().glob(f"**/{search_path}"))

    exps = []
    for marker in markers:
        run_dir = marker.parent

        # Experiment configuration; ambiguous (0 or >1 matches) -> None.
        config_files = list(run_dir.glob("**/config.yaml"))
        cfg = load_from_yaml(config_files[0]) if len(config_files) == 1 else None

        # Metrics files; sorted so the dict's key order is deterministic.
        metrics = {
            f.name[: -len(".pt")]: torch.load(f, weights_only=False)
            for f in sorted(run_dir.glob("*.pt"))
        }

        # Checkpoint paths; sorted so results don't depend on filesystem order.
        ckpts = sorted(str(c.resolve()) for c in run_dir.glob("**/*.ckpt"))

        # NOTE(review): working_dir is the *parent* of the run directory (kept
        # as-is for backward compatibility); confirm this is intended.
        exps.append(Experiment(str(run_dir.parent), cfg, ckpts, metrics))

    return exps[0] if len(exps) == 1 else exps