hydraflow 0.7.5__tar.gz → 0.8.0__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {hydraflow-0.7.5 → hydraflow-0.8.0}/PKG-INFO +2 -2
- {hydraflow-0.7.5 → hydraflow-0.8.0}/README.md +1 -1
- {hydraflow-0.7.5 → hydraflow-0.8.0}/apps/quickstart.py +5 -2
- {hydraflow-0.7.5 → hydraflow-0.8.0}/docs/usage/quickstart.md +0 -12
- {hydraflow-0.7.5 → hydraflow-0.8.0}/pyproject.toml +10 -14
- {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/__init__.py +1 -16
- {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/config.py +10 -27
- {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/context.py +6 -49
- hydraflow-0.8.0/src/hydraflow/main.py +162 -0
- hydraflow-0.8.0/src/hydraflow/mlflow.py +167 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/param.py +2 -2
- {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/run_collection.py +10 -156
- {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/run_data.py +4 -2
- {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/utils.py +19 -28
- hydraflow-0.8.0/tests/config/test_config.py +54 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/config/test_params.py +11 -27
- hydraflow-0.8.0/tests/conftest.py +46 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/context/chdir.py +5 -2
- hydraflow-0.7.5/tests/context/logging.py → hydraflow-0.8.0/tests/context/log_run.py +11 -8
- hydraflow-0.7.5/tests/context/context.py → hydraflow-0.8.0/tests/context/start_run.py +7 -13
- hydraflow-0.7.5/tests/context/test_logging.py → hydraflow-0.8.0/tests/context/test_log_run.py +23 -19
- hydraflow-0.7.5/tests/context/test_context.py → hydraflow-0.8.0/tests/context/test_start_run.py +13 -5
- hydraflow-0.7.5/tests/main/base.py → hydraflow-0.8.0/tests/main/default.py +2 -1
- hydraflow-0.8.0/tests/main/match_overrides.py +24 -0
- hydraflow-0.7.5/tests/main/restart.py → hydraflow-0.8.0/tests/main/rerun_finished.py +9 -2
- hydraflow-0.7.5/tests/main/test_base.py → hydraflow-0.8.0/tests/main/test_default.py +10 -6
- hydraflow-0.8.0/tests/main/test_match_overrides.py +22 -0
- hydraflow-0.7.5/tests/main/test_restart.py → hydraflow-0.8.0/tests/main/test_rerun_finished.py +1 -1
- hydraflow-0.7.5/tests/main/test_skip.py → hydraflow-0.8.0/tests/main/test_skip_finished.py +1 -1
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/param/params.py +5 -2
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/param/test_params.py +14 -3
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/run/filter.py +6 -3
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/run/test_collection.py +26 -69
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/run/test_data.py +2 -2
- hydraflow-0.8.0/tests/run/test_filter.py +41 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/run/test_info.py +2 -2
- hydraflow-0.8.0/tests/run/test_values.py +34 -0
- hydraflow-0.7.5/tests/run/run.py → hydraflow-0.8.0/tests/run/values.py +7 -9
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/test_mlflow.py +8 -26
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/utils/test_run.py +2 -2
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/utils/test_utils.py +1 -17
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/utils/utils.py +4 -5
- hydraflow-0.7.5/src/hydraflow/main.py +0 -54
- hydraflow-0.7.5/src/hydraflow/mlflow.py +0 -280
- hydraflow-0.7.5/tests/config/overrides.py +0 -32
- hydraflow-0.7.5/tests/config/test_config.py +0 -29
- hydraflow-0.7.5/tests/config/test_overrides.py +0 -25
- hydraflow-0.7.5/tests/conftest.py +0 -81
- hydraflow-0.7.5/tests/context/rerun.py +0 -40
- hydraflow-0.7.5/tests/context/test_rerun.py +0 -31
- hydraflow-0.7.5/tests/run/test_filter.py +0 -19
- hydraflow-0.7.5/tests/run/test_run.py +0 -54
- {hydraflow-0.7.5 → hydraflow-0.8.0}/.devcontainer/devcontainer.json +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/.devcontainer/postCreate.sh +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/.devcontainer/starship.toml +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/.gitattributes +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/.github/workflows/ci.yaml +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/.github/workflows/docs.yaml +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/.gitignore +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/LICENSE +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/docs/index.md +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/hydraflow.yaml +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/mkdocs.yaml +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/cli.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/py.typed +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/run_info.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/__init__.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/cli/__init__.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/cli/conftest.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/cli/test_run.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/cli/test_show.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/cli/test_version.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/config/__init__.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/context/__init__.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/context/test_chdir.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/main/__init__.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/main/force_new_run.py +0 -0
- hydraflow-0.7.5/tests/main/skip.py → hydraflow-0.8.0/tests/main/skip_finished.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/main/test_force_new_run.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/param/__init__.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/param/test_param.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/run/__init__.py +0 -0
- {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/utils/__init__.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: hydraflow
|
3
|
-
Version: 0.7.5
|
3
|
+
Version: 0.8.0
|
4
4
|
Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
|
5
5
|
Project-URL: Documentation, https://daizutabi.github.io/hydraflow/
|
6
6
|
Project-URL: Source, https://github.com/daizutabi/hydraflow
|
@@ -108,7 +108,7 @@ class MySQLConfig:
|
|
108
108
|
cs = ConfigStore.instance()
|
109
109
|
cs.store(name="config", node=MySQLConfig)
|
110
110
|
|
111
|
-
@hydra.main(
|
111
|
+
@hydra.main(config_name="config", version_base=None)
|
112
112
|
def my_app(cfg: MySQLConfig) -> None:
|
113
113
|
# Set experiment by Hydra job name.
|
114
114
|
hydraflow.set_experiment()
|
@@ -63,7 +63,7 @@ class MySQLConfig:
|
|
63
63
|
cs = ConfigStore.instance()
|
64
64
|
cs.store(name="config", node=MySQLConfig)
|
65
65
|
|
66
|
-
@hydra.main(
|
66
|
+
@hydra.main(config_name="config", version_base=None)
|
67
67
|
def my_app(cfg: MySQLConfig) -> None:
|
68
68
|
# Set experiment by Hydra job name.
|
69
69
|
hydraflow.set_experiment()
|
@@ -2,7 +2,9 @@ import logging
|
|
2
2
|
from dataclasses import dataclass
|
3
3
|
|
4
4
|
import hydra
|
5
|
+
import mlflow
|
5
6
|
from hydra.core.config_store import ConfigStore
|
7
|
+
from hydra.core.hydra_config import HydraConfig
|
6
8
|
|
7
9
|
import hydraflow
|
8
10
|
|
@@ -19,9 +21,10 @@ cs = ConfigStore.instance()
|
|
19
21
|
cs.store(name="config", node=Config)
|
20
22
|
|
21
23
|
|
22
|
-
@hydra.main(
|
24
|
+
@hydra.main(config_name="config", version_base=None)
|
23
25
|
def app(cfg: Config) -> None:
|
24
|
-
|
26
|
+
hc = HydraConfig.get()
|
27
|
+
mlflow.set_experiment(hc.job.name)
|
25
28
|
|
26
29
|
with hydraflow.start_run(cfg):
|
27
30
|
log.info(f"{cfg.width=}, {cfg.height=}")
|
@@ -117,18 +117,6 @@ $ python apps/quickstart.py -m width=400,600 height=100,200,300
|
|
117
117
|
>>> print(run.data.params)
|
118
118
|
```
|
119
119
|
|
120
|
-
### Map runs
|
121
|
-
|
122
|
-
```pycon exec="1" source="console" session="quickstart"
|
123
|
-
>>> params = rc.map(lambda x: x.data.params)
|
124
|
-
>>> for p in params:
|
125
|
-
... print(p)
|
126
|
-
```
|
127
|
-
|
128
|
-
```pycon exec="1" source="console" session="quickstart"
|
129
|
-
>>> list(rc.map_id(print))
|
130
|
-
```
|
131
|
-
|
132
120
|
### Group runs
|
133
121
|
|
134
122
|
```pycon exec="1" source="console" session="quickstart"
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "hydraflow"
|
7
|
-
version = "0.7.5"
|
7
|
+
version = "0.8.0"
|
8
8
|
description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
|
9
9
|
readme = "README.md"
|
10
10
|
license = { file = "LICENSE" }
|
@@ -72,24 +72,20 @@ ignore = [
|
|
72
72
|
"D203",
|
73
73
|
"D213",
|
74
74
|
"EM101",
|
75
|
+
"FBT001",
|
76
|
+
"FBT002",
|
75
77
|
"PGH003",
|
78
|
+
"PLR0911",
|
79
|
+
"PLR0913",
|
76
80
|
"PLR1704",
|
81
|
+
"PLR2004",
|
82
|
+
"SIM102",
|
83
|
+
"SIM108",
|
77
84
|
"TRY003",
|
78
85
|
]
|
79
86
|
|
80
87
|
[tool.ruff.lint.per-file-ignores]
|
81
|
-
"tests/*" = [
|
82
|
-
"A001",
|
83
|
-
"ANN",
|
84
|
-
"ARG",
|
85
|
-
"D",
|
86
|
-
"FBT",
|
87
|
-
"PLR",
|
88
|
-
"PT",
|
89
|
-
"S",
|
90
|
-
"SIM108",
|
91
|
-
"SLF",
|
92
|
-
]
|
88
|
+
"tests/*" = ["A001", "ANN", "ARG", "D", "FBT", "PD", "PLR", "PT", "S", "SLF"]
|
93
89
|
"apps/*.py" = ["D", "G", "INP"]
|
94
|
-
"src/hydraflow/main.py" = ["ANN201", "D401"
|
90
|
+
"src/hydraflow/main.py" = ["ANN201", "D401"]
|
95
91
|
"src/hydraflow/cli.py" = ["ANN", "D"]
|
@@ -1,23 +1,14 @@
|
|
1
1
|
"""Integrate Hydra and MLflow to manage and track machine learning experiments."""
|
2
2
|
|
3
|
-
from hydraflow.config import select_config, select_overrides
|
4
3
|
from hydraflow.context import chdir_artifact, log_run, start_run
|
5
4
|
from hydraflow.main import main
|
6
|
-
from hydraflow.mlflow import (
|
7
|
-
list_run_ids,
|
8
|
-
list_run_paths,
|
9
|
-
list_runs,
|
10
|
-
search_runs,
|
11
|
-
set_experiment,
|
12
|
-
)
|
5
|
+
from hydraflow.mlflow import list_run_ids, list_run_paths, list_runs
|
13
6
|
from hydraflow.run_collection import RunCollection
|
14
7
|
from hydraflow.utils import (
|
15
8
|
get_artifact_dir,
|
16
9
|
get_artifact_path,
|
17
10
|
get_hydra_output_dir,
|
18
|
-
get_overrides,
|
19
11
|
load_config,
|
20
|
-
load_overrides,
|
21
12
|
remove_run,
|
22
13
|
)
|
23
14
|
|
@@ -27,18 +18,12 @@ __all__ = [
|
|
27
18
|
"get_artifact_dir",
|
28
19
|
"get_artifact_path",
|
29
20
|
"get_hydra_output_dir",
|
30
|
-
"get_overrides",
|
31
21
|
"list_run_ids",
|
32
22
|
"list_run_paths",
|
33
23
|
"list_runs",
|
34
24
|
"load_config",
|
35
|
-
"load_overrides",
|
36
25
|
"log_run",
|
37
26
|
"main",
|
38
27
|
"remove_run",
|
39
|
-
"search_runs",
|
40
|
-
"select_config",
|
41
|
-
"select_overrides",
|
42
|
-
"set_experiment",
|
43
28
|
"start_run",
|
44
29
|
]
|
@@ -6,35 +6,19 @@ from typing import TYPE_CHECKING
|
|
6
6
|
|
7
7
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
8
8
|
|
9
|
-
from hydraflow.utils import get_overrides
|
10
|
-
|
11
9
|
if TYPE_CHECKING:
|
12
10
|
from collections.abc import Iterator
|
13
11
|
from typing import Any
|
14
12
|
|
15
13
|
|
16
|
-
def
|
17
|
-
"""Iterate over parameters and collect them into a dictionary.
|
18
|
-
|
19
|
-
Args:
|
20
|
-
config (object): The configuration object to iterate over.
|
21
|
-
prefix (str): The prefix to prepend to the parameter keys.
|
22
|
-
|
23
|
-
Returns:
|
24
|
-
dict[str, Any]: A dictionary of collected parameters.
|
25
|
-
|
26
|
-
"""
|
27
|
-
return dict(iter_params(config))
|
28
|
-
|
29
|
-
|
30
|
-
def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
14
|
+
def iter_params(config: Any, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
31
15
|
"""Recursively iterate over the parameters in the given configuration object.
|
32
16
|
|
33
17
|
This function traverses the configuration object and yields key-value pairs
|
34
18
|
representing the parameters. The keys are prefixed with the provided prefix.
|
35
19
|
|
36
20
|
Args:
|
37
|
-
config (object): The configuration object to iterate over. This can be a
|
21
|
+
config (Any): The configuration object to iterate over. This can be a
|
38
22
|
dictionary, list, DictConfig, or ListConfig.
|
39
23
|
prefix (str): The prefix to prepend to the parameter keys.
|
40
24
|
Defaults to an empty string.
|
@@ -50,7 +34,7 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
|
50
34
|
config = _from_dotlist(config)
|
51
35
|
|
52
36
|
if not isinstance(config, DictConfig | ListConfig):
|
53
|
-
config = OmegaConf.create(config)
|
37
|
+
config = OmegaConf.create(config)
|
54
38
|
|
55
39
|
yield from _iter_params(config, prefix)
|
56
40
|
|
@@ -65,7 +49,7 @@ def _from_dotlist(config: list[str]) -> dict[str, str]:
|
|
65
49
|
return result
|
66
50
|
|
67
51
|
|
68
|
-
def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
52
|
+
def _iter_params(config: Any, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
69
53
|
if isinstance(config, DictConfig):
|
70
54
|
for key, value in config.items():
|
71
55
|
if _is_param(value):
|
@@ -83,12 +67,12 @@ def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
|
83
67
|
yield from _iter_params(value, f"{prefix}{index}.")
|
84
68
|
|
85
69
|
|
86
|
-
def _is_param(value: object) -> bool:
|
70
|
+
def _is_param(value: Any) -> bool:
|
87
71
|
"""Check if the given value is a parameter."""
|
88
72
|
if isinstance(value, DictConfig):
|
89
73
|
return False
|
90
74
|
|
91
|
-
if isinstance(value, ListConfig):
|
75
|
+
if isinstance(value, ListConfig):
|
92
76
|
if any(isinstance(v, DictConfig | ListConfig) for v in value):
|
93
77
|
return False
|
94
78
|
|
@@ -103,14 +87,14 @@ def _convert(value: Any) -> Any:
|
|
103
87
|
return value
|
104
88
|
|
105
89
|
|
106
|
-
def select_config(config: object, names: list[str]) -> dict[str, Any]:
|
90
|
+
def select_config(config: Any, names: list[str]) -> dict[str, Any]:
|
107
91
|
"""Select the given parameters from the configuration object.
|
108
92
|
|
109
93
|
This function selects the given parameters from the configuration object
|
110
94
|
and returns a new configuration object containing only the selected parameters.
|
111
95
|
|
112
96
|
Args:
|
113
|
-
config (object): The configuration object to select parameters from.
|
97
|
+
config (Any): The configuration object to select parameters from.
|
114
98
|
names (list[str]): The names of the parameters to select.
|
115
99
|
|
116
100
|
Returns:
|
@@ -120,7 +104,7 @@ def select_config(config: object, names: list[str]) -> dict[str, Any]:
|
|
120
104
|
if not isinstance(config, DictConfig):
|
121
105
|
config = OmegaConf.structured(config)
|
122
106
|
|
123
|
-
return {name: _get(config, name) for name in names}
|
107
|
+
return {name: _get(config, name) for name in names}
|
124
108
|
|
125
109
|
|
126
110
|
def _get(config: DictConfig, name: str) -> Any:
|
@@ -132,8 +116,7 @@ def _get(config: DictConfig, name: str) -> Any:
|
|
132
116
|
return _get(config.get(prefix), name)
|
133
117
|
|
134
118
|
|
135
|
-
def select_overrides(config: object) -> dict[str, Any]:
|
119
|
+
def select_overrides(config: object, overrides: list[str]) -> dict[str, Any]:
|
136
120
|
"""Select the given overrides from the configuration object."""
|
137
|
-
overrides = get_overrides()
|
138
121
|
names = [override.split("=")[0].strip() for override in overrides]
|
139
122
|
return select_config(config, names)
|
@@ -12,7 +12,7 @@ import mlflow
|
|
12
12
|
import mlflow.artifacts
|
13
13
|
from hydra.core.hydra_config import HydraConfig
|
14
14
|
|
15
|
-
from hydraflow.mlflow import log_params
|
15
|
+
from hydraflow.mlflow import log_params, log_text
|
16
16
|
from hydraflow.utils import get_artifact_dir
|
17
17
|
|
18
18
|
if TYPE_CHECKING:
|
@@ -55,11 +55,11 @@ def log_run(
|
|
55
55
|
log_params(config, synchronous=synchronous)
|
56
56
|
|
57
57
|
hc = HydraConfig.get()
|
58
|
-
|
58
|
+
hydra_dir = Path(hc.runtime.output_dir)
|
59
59
|
|
60
60
|
# Save '.hydra' config directory.
|
61
|
-
|
62
|
-
mlflow.log_artifacts(
|
61
|
+
hydra_subdir = hydra_dir / (hc.output_subdir or "")
|
62
|
+
mlflow.log_artifacts(hydra_subdir.as_posix(), hc.output_subdir)
|
63
63
|
|
64
64
|
try:
|
65
65
|
yield
|
@@ -70,43 +70,14 @@ def log_run(
|
|
70
70
|
raise
|
71
71
|
|
72
72
|
finally:
|
73
|
-
log_text(
|
74
|
-
|
75
|
-
|
76
|
-
def log_text(directory: Path, pattern: str = "*.log") -> None:
|
77
|
-
"""Log text files in the given directory as artifacts.
|
78
|
-
|
79
|
-
Append the text files to the existing text file in the artifact directory.
|
80
|
-
|
81
|
-
Args:
|
82
|
-
directory (Path): The directory to find the logs in.
|
83
|
-
pattern (str): The pattern to match the logs.
|
84
|
-
|
85
|
-
"""
|
86
|
-
artifact_dir = get_artifact_dir()
|
87
|
-
|
88
|
-
for file in directory.glob(pattern):
|
89
|
-
if not file.is_file():
|
90
|
-
continue
|
91
|
-
|
92
|
-
file_artifact = artifact_dir / file.name
|
93
|
-
if file_artifact.exists():
|
94
|
-
text = file_artifact.read_text()
|
95
|
-
if not text.endswith("\n"):
|
96
|
-
text += "\n"
|
97
|
-
else:
|
98
|
-
text = ""
|
99
|
-
|
100
|
-
text += file.read_text()
|
101
|
-
mlflow.log_text(text, file.name)
|
73
|
+
log_text(hydra_dir)
|
102
74
|
|
103
75
|
|
104
76
|
@contextmanager
|
105
|
-
def start_run(  # noqa: PLR0913
|
77
|
+
def start_run(
|
106
78
|
config: object,
|
107
79
|
*,
|
108
80
|
chdir: bool = False,
|
109
|
-
run: Run | None = None,
|
110
81
|
run_id: str | None = None,
|
111
82
|
experiment_id: str | None = None,
|
112
83
|
run_name: str | None = None,
|
@@ -126,7 +97,6 @@ def start_run( # noqa: PLR0913
|
|
126
97
|
config (object): The configuration object to log parameters from.
|
127
98
|
chdir (bool): Whether to change the current working directory to the
|
128
99
|
artifact directory of the current run. Defaults to False.
|
129
|
-
run (Run | None): The existing run. Defaults to None.
|
130
100
|
run_id (str | None): The existing run ID. Defaults to None.
|
131
101
|
experiment_id (str | None): The experiment ID. Defaults to None.
|
132
102
|
run_name (str | None): The name of the run. Defaults to None.
|
@@ -142,20 +112,7 @@ def start_run( # noqa: PLR0913
|
|
142
112
|
Yields:
|
143
113
|
Run: An MLflow Run object representing the started run.
|
144
114
|
|
145
|
-
Example:
|
146
|
-
with start_run(config) as run:
|
147
|
-
# Perform operations within the MLflow run context
|
148
|
-
pass
|
149
|
-
|
150
|
-
See Also:
|
151
|
-
- `mlflow.start_run`: The MLflow function to start a run directly.
|
152
|
-
- `log_run`: A context manager to log parameters and manage the MLflow
|
153
|
-
run context.
|
154
|
-
|
155
115
|
"""
|
156
|
-
if run:
|
157
|
-
run_id = run.info.run_id
|
158
|
-
|
159
116
|
with (
|
160
117
|
mlflow.start_run(
|
161
118
|
run_id=run_id,
|
@@ -0,0 +1,162 @@
|
|
1
|
+
"""Integration of MLflow experiment tracking with Hydra configuration management.
|
2
|
+
|
3
|
+
This module provides decorators and utilities to seamlessly combine Hydra's
|
4
|
+
configuration management with MLflow's experiment tracking capabilities. It
|
5
|
+
enables automatic run deduplication, configuration storage, and experiment
|
6
|
+
management.
|
7
|
+
|
8
|
+
The main functionality is provided through the `main` decorator, which can be
|
9
|
+
used to wrap experiment entry points. This decorator handles:
|
10
|
+
- Configuration management via Hydra
|
11
|
+
- Experiment tracking via MLflow
|
12
|
+
- Run deduplication based on configurations
|
13
|
+
- Working directory management
|
14
|
+
- Automatic configuration storage
|
15
|
+
|
16
|
+
Example:
|
17
|
+
```python
|
18
|
+
from dataclasses import dataclass
|
19
|
+
from mlflow.entities import Run
|
20
|
+
|
21
|
+
@dataclass
|
22
|
+
class Config:
|
23
|
+
learning_rate: float
|
24
|
+
batch_size: int
|
25
|
+
|
26
|
+
@main(Config)
|
27
|
+
def train(run: Run, config: Config):
|
28
|
+
# Your training code here
|
29
|
+
pass
|
30
|
+
```
|
31
|
+
|
32
|
+
"""
|
33
|
+
|
34
|
+
from __future__ import annotations
|
35
|
+
|
36
|
+
from functools import wraps
|
37
|
+
from typing import TYPE_CHECKING, TypeVar
|
38
|
+
|
39
|
+
import hydra
|
40
|
+
import mlflow
|
41
|
+
from hydra.core.config_store import ConfigStore
|
42
|
+
from hydra.core.hydra_config import HydraConfig
|
43
|
+
from mlflow.entities import RunStatus
|
44
|
+
from omegaconf import OmegaConf
|
45
|
+
|
46
|
+
import hydraflow
|
47
|
+
from hydraflow.utils import file_uri_to_path
|
48
|
+
|
49
|
+
if TYPE_CHECKING:
|
50
|
+
from collections.abc import Callable
|
51
|
+
from pathlib import Path
|
52
|
+
|
53
|
+
from mlflow.entities import Run
|
54
|
+
|
55
|
+
FINISHED = RunStatus.to_string(RunStatus.FINISHED)
|
56
|
+
|
57
|
+
T = TypeVar("T")
|
58
|
+
|
59
|
+
|
60
|
+
def main(
|
61
|
+
node: T | type[T],
|
62
|
+
config_name: str = "config",
|
63
|
+
*,
|
64
|
+
chdir: bool = False,
|
65
|
+
force_new_run: bool = False,
|
66
|
+
match_overrides: bool = False,
|
67
|
+
rerun_finished: bool = False,
|
68
|
+
):
|
69
|
+
"""Decorator for configuring and running MLflow experiments with Hydra.
|
70
|
+
|
71
|
+
This decorator combines Hydra configuration management with MLflow experiment
|
72
|
+
tracking. It automatically handles run deduplication and configuration storage.
|
73
|
+
|
74
|
+
Args:
|
75
|
+
node: Configuration node class or instance defining the structure of the
|
76
|
+
configuration.
|
77
|
+
config_name: Name of the configuration. Defaults to "config".
|
78
|
+
chdir: If True, changes working directory to the artifact directory
|
79
|
+
of the run. Defaults to False.
|
80
|
+
force_new_run: If True, always creates a new MLflow run instead of
|
81
|
+
reusing existing ones. Defaults to False.
|
82
|
+
match_overrides: If True, matches runs based on Hydra CLI overrides
|
83
|
+
instead of full config. Defaults to False.
|
84
|
+
rerun_finished: If True, allows rerunning completed runs. Defaults to
|
85
|
+
False.
|
86
|
+
|
87
|
+
"""
|
88
|
+
|
89
|
+
def decorator(app: Callable[[Run, T], None]) -> Callable[[], None]:
|
90
|
+
ConfigStore.instance().store(config_name, node)
|
91
|
+
|
92
|
+
@hydra.main(config_name=config_name, version_base=None)
|
93
|
+
@wraps(app)
|
94
|
+
def inner_decorator(config: T) -> None:
|
95
|
+
hc = HydraConfig.get()
|
96
|
+
experiment = mlflow.set_experiment(hc.job.name)
|
97
|
+
|
98
|
+
if force_new_run:
|
99
|
+
run_id = None
|
100
|
+
else:
|
101
|
+
uri = experiment.artifact_location
|
102
|
+
overrides = hc.overrides.task if match_overrides else None
|
103
|
+
run_id = get_run_id(uri, config, overrides)
|
104
|
+
|
105
|
+
if run_id and not rerun_finished:
|
106
|
+
run = mlflow.get_run(run_id)
|
107
|
+
if run.info.status == FINISHED:
|
108
|
+
return
|
109
|
+
|
110
|
+
with hydraflow.start_run(config, run_id=run_id, chdir=chdir) as run:
|
111
|
+
app(run, config)
|
112
|
+
|
113
|
+
return inner_decorator
|
114
|
+
|
115
|
+
return decorator
|
116
|
+
|
117
|
+
|
118
|
+
def get_run_id(uri: str, config: object, overrides: list[str] | None) -> str | None:
|
119
|
+
"""Try to get the run ID for the given configuration.
|
120
|
+
|
121
|
+
If the run is not found, the function will return None.
|
122
|
+
|
123
|
+
Args:
|
124
|
+
uri (str): The URI of the experiment.
|
125
|
+
config (object): The configuration object.
|
126
|
+
overrides (list[str] | None): The task overrides.
|
127
|
+
|
128
|
+
Returns:
|
129
|
+
The run ID for the given configuration or overrides. Returns None if
|
130
|
+
no run ID is found.
|
131
|
+
|
132
|
+
"""
|
133
|
+
for run_dir in file_uri_to_path(uri).iterdir():
|
134
|
+
if run_dir.is_dir() and equals(run_dir, config, overrides):
|
135
|
+
return run_dir.name
|
136
|
+
|
137
|
+
return None
|
138
|
+
|
139
|
+
|
140
|
+
def equals(run_dir: Path, config: object, overrides: list[str] | None) -> bool:
|
141
|
+
"""Check if the run directory matches the given configuration or overrides.
|
142
|
+
|
143
|
+
Args:
|
144
|
+
run_dir (Path): The run directory.
|
145
|
+
config (object): The configuration object.
|
146
|
+
overrides (list[str] | None): The task overrides.
|
147
|
+
|
148
|
+
Returns:
|
149
|
+
True if the run directory matches the given configuration or overrides,
|
150
|
+
False otherwise.
|
151
|
+
|
152
|
+
"""
|
153
|
+
if overrides is None:
|
154
|
+
path = run_dir / "artifacts/.hydra/config.yaml"
|
155
|
+
else:
|
156
|
+
path = run_dir / "artifacts/.hydra/overrides.yaml"
|
157
|
+
config = overrides
|
158
|
+
|
159
|
+
if not path.exists():
|
160
|
+
return False
|
161
|
+
|
162
|
+
return OmegaConf.load(path) == config
|
@@ -0,0 +1,167 @@
|
|
1
|
+
"""Integration of MLflow experiment tracking with Hydra configuration management.
|
2
|
+
|
3
|
+
This module provides functions to log parameters from Hydra configuration objects
|
4
|
+
to MLflow, set experiments, and manage tracking URIs. It integrates Hydra's
|
5
|
+
configuration management with MLflow's experiment tracking capabilities.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from __future__ import annotations
|
9
|
+
|
10
|
+
from typing import TYPE_CHECKING
|
11
|
+
|
12
|
+
import joblib
|
13
|
+
import mlflow
|
14
|
+
import mlflow.artifacts
|
15
|
+
|
16
|
+
from hydraflow.config import iter_params
|
17
|
+
from hydraflow.run_collection import RunCollection
|
18
|
+
from hydraflow.utils import file_uri_to_path, get_artifact_dir
|
19
|
+
|
20
|
+
if TYPE_CHECKING:
|
21
|
+
from pathlib import Path
|
22
|
+
from typing import Any
|
23
|
+
|
24
|
+
|
25
|
+
def log_params(config: Any, *, synchronous: bool | None = None) -> None:
|
26
|
+
"""Log the parameters from the given configuration object.
|
27
|
+
|
28
|
+
This method logs the parameters from the provided configuration object
|
29
|
+
using MLflow. It iterates over the parameters and logs them using the
|
30
|
+
`mlflow.log_param` method.
|
31
|
+
|
32
|
+
Args:
|
33
|
+
config (Any): The configuration object to log the parameters from.
|
34
|
+
synchronous (bool | None): Whether to log the parameters synchronously.
|
35
|
+
Defaults to None.
|
36
|
+
|
37
|
+
"""
|
38
|
+
for key, value in iter_params(config):
|
39
|
+
mlflow.log_param(key, value, synchronous=synchronous)
|
40
|
+
|
41
|
+
|
42
|
+
def log_text(from_dir: Path, pattern: str = "*.log") -> None:
|
43
|
+
"""Log text files in the given directory as artifacts.
|
44
|
+
|
45
|
+
Append the text files to the existing text file in the artifact directory.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
from_dir (Path): The directory to find the logs in.
|
49
|
+
pattern (str): The pattern to match the logs.
|
50
|
+
|
51
|
+
"""
|
52
|
+
artifact_dir = get_artifact_dir()
|
53
|
+
|
54
|
+
for file in from_dir.glob(pattern):
|
55
|
+
if not file.is_file():
|
56
|
+
continue
|
57
|
+
|
58
|
+
file_artifact = artifact_dir / file.name
|
59
|
+
if file_artifact.exists():
|
60
|
+
text = file_artifact.read_text()
|
61
|
+
if not text.endswith("\n"):
|
62
|
+
text += "\n"
|
63
|
+
else:
|
64
|
+
text = ""
|
65
|
+
|
66
|
+
text += file.read_text()
|
67
|
+
mlflow.log_text(text, file.name)
|
68
|
+
|
69
|
+
|
70
|
+
def list_run_paths(
|
71
|
+
experiment_names: str | list[str] | None = None,
|
72
|
+
*other: str,
|
73
|
+
) -> list[Path]:
|
74
|
+
"""List all run paths for the specified experiments.
|
75
|
+
|
76
|
+
This function retrieves all run paths for the given list of experiment names.
|
77
|
+
If no experiment names are provided (None), the function will search all runs
|
78
|
+
for all experiments except the "Default" experiment.
|
79
|
+
|
80
|
+
Args:
|
81
|
+
experiment_names (list[str] | None): List of experiment names to search
|
82
|
+
for runs. If None is provided, the function will search all runs
|
83
|
+
for all experiments except the "Default" experiment.
|
84
|
+
*other (str): The parts of the run directory to join.
|
85
|
+
|
86
|
+
Returns:
|
87
|
+
list[Path]: A list of run paths for the specified experiments.
|
88
|
+
|
89
|
+
"""
|
90
|
+
if isinstance(experiment_names, str):
|
91
|
+
experiment_names = [experiment_names]
|
92
|
+
|
93
|
+
elif experiment_names is None:
|
94
|
+
experiments = mlflow.search_experiments()
|
95
|
+
experiment_names = [e.name for e in experiments if e.name != "Default"]
|
96
|
+
|
97
|
+
run_paths: list[Path] = []
|
98
|
+
|
99
|
+
for name in experiment_names:
|
100
|
+
if experiment := mlflow.get_experiment_by_name(name):
|
101
|
+
uri = experiment.artifact_location
|
102
|
+
|
103
|
+
if isinstance(uri, str):
|
104
|
+
path = file_uri_to_path(uri)
|
105
|
+
run_paths.extend(p for p in path.iterdir() if p.is_dir())
|
106
|
+
|
107
|
+
if other:
|
108
|
+
return [p.joinpath(*other) for p in run_paths]
|
109
|
+
|
110
|
+
return run_paths
|
111
|
+
|
112
|
+
|
113
|
+
def list_run_ids(experiment_names: str | list[str] | None = None) -> list[str]:
|
114
|
+
"""List all run IDs for the specified experiments.
|
115
|
+
|
116
|
+
This function retrieves all runs for the given list of experiment names.
|
117
|
+
If no experiment names are provided (None), the function will search all
|
118
|
+
runs for all experiments except the "Default" experiment.
|
119
|
+
|
120
|
+
Args:
|
121
|
+
experiment_names (list[str] | None): List of experiment names to search
|
122
|
+
for runs. If None is provided, the function will search all runs
|
123
|
+
for all experiments except the "Default" experiment.
|
124
|
+
|
125
|
+
Returns:
|
126
|
+
list[str]: A list of run IDs for the specified experiments.
|
127
|
+
|
128
|
+
"""
|
129
|
+
return [run_path.stem for run_path in list_run_paths(experiment_names)]
|
130
|
+
|
131
|
+
|
132
|
+
def list_runs(
|
133
|
+
experiment_names: str | list[str] | None = None,
|
134
|
+
n_jobs: int = 0,
|
135
|
+
) -> RunCollection:
|
136
|
+
"""List all runs for the specified experiments.
|
137
|
+
|
138
|
+
This function retrieves all runs for the given list of experiment names.
|
139
|
+
If no experiment names are provided (None), the function will search all runs
|
140
|
+
for all experiments except the "Default" experiment.
|
141
|
+
The function returns the results as a `RunCollection` object.
|
142
|
+
|
143
|
+
Note:
|
144
|
+
The returned runs are sorted by their start time in ascending order.
|
145
|
+
|
146
|
+
Args:
|
147
|
+
experiment_names (list[str] | None): List of experiment names to search
|
148
|
+
for runs. If None is provided, the function will search all runs
|
149
|
+
for all experiments except the "Default" experiment.
|
150
|
+
n_jobs (int): The number of jobs to retrieve runs in parallel.
|
151
|
+
|
152
|
+
Returns:
|
153
|
+
RunCollection: A `RunCollection` instance containing the runs for the
|
154
|
+
specified experiments.
|
155
|
+
|
156
|
+
"""
|
157
|
+
run_ids = list_run_ids(experiment_names)
|
158
|
+
|
159
|
+
if n_jobs == 0:
|
160
|
+
runs = [mlflow.get_run(run_id) for run_id in run_ids]
|
161
|
+
|
162
|
+
else:
|
163
|
+
it = (joblib.delayed(mlflow.get_run)(run_id) for run_id in run_ids)
|
164
|
+
runs = joblib.Parallel(n_jobs, backend="threading")(it)
|
165
|
+
|
166
|
+
runs = sorted(runs, key=lambda run: run.info.start_time) # type: ignore
|
167
|
+
return RunCollection(runs) # type: ignore
|