hydraflow 0.7.5__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydraflow/__init__.py +5 -20
- hydraflow/cli.py +31 -39
- hydraflow/core/__init__.py +0 -0
- hydraflow/{config.py → core/config.py} +10 -27
- hydraflow/{context.py → core/context.py} +8 -50
- hydraflow/{utils.py → core/io.py} +19 -28
- hydraflow/core/main.py +164 -0
- hydraflow/core/mlflow.py +168 -0
- hydraflow/{param.py → core/param.py} +2 -2
- hydraflow/entities/__init__.py +0 -0
- hydraflow/{run_collection.py → entities/run_collection.py} +18 -163
- hydraflow/{run_data.py → entities/run_data.py} +5 -3
- hydraflow/{run_info.py → entities/run_info.py} +2 -2
- hydraflow/executor/__init__.py +0 -0
- hydraflow/executor/conf.py +23 -0
- hydraflow/executor/io.py +34 -0
- hydraflow/executor/job.py +152 -0
- hydraflow/executor/parser.py +397 -0
- {hydraflow-0.7.5.dist-info → hydraflow-0.9.0.dist-info}/METADATA +18 -19
- hydraflow-0.9.0.dist-info/RECORD +24 -0
- hydraflow/main.py +0 -54
- hydraflow/mlflow.py +0 -280
- hydraflow-0.7.5.dist-info/RECORD +0 -17
- {hydraflow-0.7.5.dist-info → hydraflow-0.9.0.dist-info}/WHEEL +0 -0
- {hydraflow-0.7.5.dist-info → hydraflow-0.9.0.dist-info}/entry_points.txt +0 -0
- {hydraflow-0.7.5.dist-info → hydraflow-0.9.0.dist-info}/licenses/LICENSE +0 -0
hydraflow/__init__.py
CHANGED
@@ -1,25 +1,16 @@
|
|
1
1
|
"""Integrate Hydra and MLflow to manage and track machine learning experiments."""
|
2
2
|
|
3
|
-
from hydraflow.
|
4
|
-
from hydraflow.
|
5
|
-
from hydraflow.main import main
|
6
|
-
from hydraflow.mlflow import (
|
7
|
-
list_run_ids,
|
8
|
-
list_run_paths,
|
9
|
-
list_runs,
|
10
|
-
search_runs,
|
11
|
-
set_experiment,
|
12
|
-
)
|
13
|
-
from hydraflow.run_collection import RunCollection
|
14
|
-
from hydraflow.utils import (
|
3
|
+
from hydraflow.core.context import chdir_artifact, log_run, start_run
|
4
|
+
from hydraflow.core.io import (
|
15
5
|
get_artifact_dir,
|
16
6
|
get_artifact_path,
|
17
7
|
get_hydra_output_dir,
|
18
|
-
get_overrides,
|
19
8
|
load_config,
|
20
|
-
load_overrides,
|
21
9
|
remove_run,
|
22
10
|
)
|
11
|
+
from hydraflow.core.main import main
|
12
|
+
from hydraflow.core.mlflow import list_run_ids, list_run_paths, list_runs
|
13
|
+
from hydraflow.entities.run_collection import RunCollection
|
23
14
|
|
24
15
|
__all__ = [
|
25
16
|
"RunCollection",
|
@@ -27,18 +18,12 @@ __all__ = [
|
|
27
18
|
"get_artifact_dir",
|
28
19
|
"get_artifact_path",
|
29
20
|
"get_hydra_output_dir",
|
30
|
-
"get_overrides",
|
31
21
|
"list_run_ids",
|
32
22
|
"list_run_paths",
|
33
23
|
"list_runs",
|
34
24
|
"load_config",
|
35
|
-
"load_overrides",
|
36
25
|
"log_run",
|
37
26
|
"main",
|
38
27
|
"remove_run",
|
39
|
-
"search_runs",
|
40
|
-
"select_config",
|
41
|
-
"select_overrides",
|
42
|
-
"set_experiment",
|
43
28
|
"start_run",
|
44
29
|
]
|
hydraflow/cli.py
CHANGED
@@ -2,41 +2,54 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
|
5
|
-
from
|
6
|
-
from typing import Annotated
|
5
|
+
from typing import TYPE_CHECKING, Annotated
|
7
6
|
|
8
7
|
import typer
|
9
|
-
from omegaconf import DictConfig, OmegaConf
|
10
8
|
from rich.console import Console
|
11
9
|
from typer import Argument, Option
|
12
10
|
|
11
|
+
from hydraflow.executor.io import load_config
|
12
|
+
|
13
|
+
if TYPE_CHECKING:
|
14
|
+
from hydraflow.executor.job import Job
|
15
|
+
|
13
16
|
app = typer.Typer(add_completion=False)
|
14
17
|
console = Console()
|
15
18
|
|
16
19
|
|
20
|
+
def get_job(name: str) -> Job:
|
21
|
+
cfg = load_config()
|
22
|
+
job = cfg.jobs[name]
|
23
|
+
|
24
|
+
if not job.name:
|
25
|
+
job.name = name
|
26
|
+
|
27
|
+
return job
|
28
|
+
|
29
|
+
|
17
30
|
@app.command()
|
18
31
|
def run(
|
19
|
-
|
20
|
-
list[str] | None,
|
21
|
-
Argument(help="Job names.", show_default=False),
|
22
|
-
] = None,
|
32
|
+
name: Annotated[str, Argument(help="Job name.", show_default=False)],
|
23
33
|
) -> None:
|
24
|
-
"""Run
|
25
|
-
|
34
|
+
"""Run a job."""
|
35
|
+
import mlflow
|
26
36
|
|
27
|
-
|
28
|
-
|
37
|
+
from hydraflow.executor.job import multirun
|
38
|
+
|
39
|
+
job = get_job(name)
|
40
|
+
mlflow.set_experiment(job.name)
|
41
|
+
multirun(job)
|
29
42
|
|
30
43
|
|
31
44
|
@app.command()
|
32
|
-
def show(
|
33
|
-
"
|
34
|
-
|
45
|
+
def show(
|
46
|
+
name: Annotated[str, Argument(help="Job name.", show_default=False)],
|
47
|
+
) -> None:
|
48
|
+
"""Show a job."""
|
49
|
+
from hydraflow.executor.job import show
|
35
50
|
|
36
|
-
|
37
|
-
|
38
|
-
syntax = Syntax(code, "yaml")
|
39
|
-
console.print(syntax)
|
51
|
+
job = get_job(name)
|
52
|
+
show(job)
|
40
53
|
|
41
54
|
|
42
55
|
@app.callback(invoke_without_command=True)
|
@@ -52,24 +65,3 @@ def callback(
|
|
52
65
|
|
53
66
|
typer.echo(f"hydraflow {importlib.metadata.version('hydraflow')}")
|
54
67
|
raise typer.Exit
|
55
|
-
|
56
|
-
|
57
|
-
def find_config() -> Path:
|
58
|
-
if Path("hydraflow.yaml").exists():
|
59
|
-
return Path("hydraflow.yaml")
|
60
|
-
|
61
|
-
if Path("hydraflow.yml").exists():
|
62
|
-
return Path("hydraflow.yml")
|
63
|
-
|
64
|
-
typer.echo("No config file found.")
|
65
|
-
raise typer.Exit(code=1)
|
66
|
-
|
67
|
-
|
68
|
-
def load_config() -> DictConfig:
|
69
|
-
cfg = OmegaConf.load(find_config())
|
70
|
-
|
71
|
-
if isinstance(cfg, DictConfig):
|
72
|
-
return cfg
|
73
|
-
|
74
|
-
typer.echo("Invalid config file.")
|
75
|
-
raise typer.Exit(code=1)
|
File without changes
|
@@ -6,35 +6,19 @@ from typing import TYPE_CHECKING
|
|
6
6
|
|
7
7
|
from omegaconf import DictConfig, ListConfig, OmegaConf
|
8
8
|
|
9
|
-
from hydraflow.utils import get_overrides
|
10
|
-
|
11
9
|
if TYPE_CHECKING:
|
12
10
|
from collections.abc import Iterator
|
13
11
|
from typing import Any
|
14
12
|
|
15
13
|
|
16
|
-
def
|
17
|
-
"""Iterate over parameters and collect them into a dictionary.
|
18
|
-
|
19
|
-
Args:
|
20
|
-
config (object): The configuration object to iterate over.
|
21
|
-
prefix (str): The prefix to prepend to the parameter keys.
|
22
|
-
|
23
|
-
Returns:
|
24
|
-
dict[str, Any]: A dictionary of collected parameters.
|
25
|
-
|
26
|
-
"""
|
27
|
-
return dict(iter_params(config))
|
28
|
-
|
29
|
-
|
30
|
-
def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
14
|
+
def iter_params(config: Any, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
31
15
|
"""Recursively iterate over the parameters in the given configuration object.
|
32
16
|
|
33
17
|
This function traverses the configuration object and yields key-value pairs
|
34
18
|
representing the parameters. The keys are prefixed with the provided prefix.
|
35
19
|
|
36
20
|
Args:
|
37
|
-
config (
|
21
|
+
config (Any): The configuration object to iterate over. This can be a
|
38
22
|
dictionary, list, DictConfig, or ListConfig.
|
39
23
|
prefix (str): The prefix to prepend to the parameter keys.
|
40
24
|
Defaults to an empty string.
|
@@ -50,7 +34,7 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
|
50
34
|
config = _from_dotlist(config)
|
51
35
|
|
52
36
|
if not isinstance(config, DictConfig | ListConfig):
|
53
|
-
config = OmegaConf.create(config)
|
37
|
+
config = OmegaConf.create(config)
|
54
38
|
|
55
39
|
yield from _iter_params(config, prefix)
|
56
40
|
|
@@ -65,7 +49,7 @@ def _from_dotlist(config: list[str]) -> dict[str, str]:
|
|
65
49
|
return result
|
66
50
|
|
67
51
|
|
68
|
-
def _iter_params(config:
|
52
|
+
def _iter_params(config: Any, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
69
53
|
if isinstance(config, DictConfig):
|
70
54
|
for key, value in config.items():
|
71
55
|
if _is_param(value):
|
@@ -83,12 +67,12 @@ def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
|
83
67
|
yield from _iter_params(value, f"{prefix}{index}.")
|
84
68
|
|
85
69
|
|
86
|
-
def _is_param(value:
|
70
|
+
def _is_param(value: Any) -> bool:
|
87
71
|
"""Check if the given value is a parameter."""
|
88
72
|
if isinstance(value, DictConfig):
|
89
73
|
return False
|
90
74
|
|
91
|
-
if isinstance(value, ListConfig):
|
75
|
+
if isinstance(value, ListConfig):
|
92
76
|
if any(isinstance(v, DictConfig | ListConfig) for v in value):
|
93
77
|
return False
|
94
78
|
|
@@ -103,14 +87,14 @@ def _convert(value: Any) -> Any:
|
|
103
87
|
return value
|
104
88
|
|
105
89
|
|
106
|
-
def select_config(config:
|
90
|
+
def select_config(config: Any, names: list[str]) -> dict[str, Any]:
|
107
91
|
"""Select the given parameters from the configuration object.
|
108
92
|
|
109
93
|
This function selects the given parameters from the configuration object
|
110
94
|
and returns a new configuration object containing only the selected parameters.
|
111
95
|
|
112
96
|
Args:
|
113
|
-
config (
|
97
|
+
config (Any): The configuration object to select parameters from.
|
114
98
|
names (list[str]): The names of the parameters to select.
|
115
99
|
|
116
100
|
Returns:
|
@@ -120,7 +104,7 @@ def select_config(config: object, names: list[str]) -> dict[str, Any]:
|
|
120
104
|
if not isinstance(config, DictConfig):
|
121
105
|
config = OmegaConf.structured(config)
|
122
106
|
|
123
|
-
return {name: _get(config, name) for name in names}
|
107
|
+
return {name: _get(config, name) for name in names}
|
124
108
|
|
125
109
|
|
126
110
|
def _get(config: DictConfig, name: str) -> Any:
|
@@ -132,8 +116,7 @@ def _get(config: DictConfig, name: str) -> Any:
|
|
132
116
|
return _get(config.get(prefix), name)
|
133
117
|
|
134
118
|
|
135
|
-
def select_overrides(config: object) -> dict[str, Any]:
|
119
|
+
def select_overrides(config: object, overrides: list[str]) -> dict[str, Any]:
|
136
120
|
"""Select the given overrides from the configuration object."""
|
137
|
-
overrides = get_overrides()
|
138
121
|
names = [override.split("=")[0].strip() for override in overrides]
|
139
122
|
return select_config(config, names)
|
@@ -12,8 +12,9 @@ import mlflow
|
|
12
12
|
import mlflow.artifacts
|
13
13
|
from hydra.core.hydra_config import HydraConfig
|
14
14
|
|
15
|
-
from hydraflow.
|
16
|
-
|
15
|
+
from hydraflow.core.io import get_artifact_dir
|
16
|
+
|
17
|
+
from .mlflow import log_params, log_text
|
17
18
|
|
18
19
|
if TYPE_CHECKING:
|
19
20
|
from collections.abc import Iterator
|
@@ -55,11 +56,11 @@ def log_run(
|
|
55
56
|
log_params(config, synchronous=synchronous)
|
56
57
|
|
57
58
|
hc = HydraConfig.get()
|
58
|
-
|
59
|
+
hydra_dir = Path(hc.runtime.output_dir)
|
59
60
|
|
60
61
|
# Save '.hydra' config directory.
|
61
|
-
|
62
|
-
mlflow.log_artifacts(
|
62
|
+
hydra_subdir = hydra_dir / (hc.output_subdir or "")
|
63
|
+
mlflow.log_artifacts(hydra_subdir.as_posix(), hc.output_subdir)
|
63
64
|
|
64
65
|
try:
|
65
66
|
yield
|
@@ -70,43 +71,14 @@ def log_run(
|
|
70
71
|
raise
|
71
72
|
|
72
73
|
finally:
|
73
|
-
log_text(
|
74
|
-
|
75
|
-
|
76
|
-
def log_text(directory: Path, pattern: str = "*.log") -> None:
|
77
|
-
"""Log text files in the given directory as artifacts.
|
78
|
-
|
79
|
-
Append the text files to the existing text file in the artifact directory.
|
80
|
-
|
81
|
-
Args:
|
82
|
-
directory (Path): The directory to find the logs in.
|
83
|
-
pattern (str): The pattern to match the logs.
|
84
|
-
|
85
|
-
"""
|
86
|
-
artifact_dir = get_artifact_dir()
|
87
|
-
|
88
|
-
for file in directory.glob(pattern):
|
89
|
-
if not file.is_file():
|
90
|
-
continue
|
91
|
-
|
92
|
-
file_artifact = artifact_dir / file.name
|
93
|
-
if file_artifact.exists():
|
94
|
-
text = file_artifact.read_text()
|
95
|
-
if not text.endswith("\n"):
|
96
|
-
text += "\n"
|
97
|
-
else:
|
98
|
-
text = ""
|
99
|
-
|
100
|
-
text += file.read_text()
|
101
|
-
mlflow.log_text(text, file.name)
|
74
|
+
log_text(hydra_dir)
|
102
75
|
|
103
76
|
|
104
77
|
@contextmanager
|
105
|
-
def start_run(
|
78
|
+
def start_run(
|
106
79
|
config: object,
|
107
80
|
*,
|
108
81
|
chdir: bool = False,
|
109
|
-
run: Run | None = None,
|
110
82
|
run_id: str | None = None,
|
111
83
|
experiment_id: str | None = None,
|
112
84
|
run_name: str | None = None,
|
@@ -126,7 +98,6 @@ def start_run( # noqa: PLR0913
|
|
126
98
|
config (object): The configuration object to log parameters from.
|
127
99
|
chdir (bool): Whether to change the current working directory to the
|
128
100
|
artifact directory of the current run. Defaults to False.
|
129
|
-
run (Run | None): The existing run. Defaults to None.
|
130
101
|
run_id (str | None): The existing run ID. Defaults to None.
|
131
102
|
experiment_id (str | None): The experiment ID. Defaults to None.
|
132
103
|
run_name (str | None): The name of the run. Defaults to None.
|
@@ -142,20 +113,7 @@ def start_run( # noqa: PLR0913
|
|
142
113
|
Yields:
|
143
114
|
Run: An MLflow Run object representing the started run.
|
144
115
|
|
145
|
-
Example:
|
146
|
-
with start_run(config) as run:
|
147
|
-
# Perform operations within the MLflow run context
|
148
|
-
pass
|
149
|
-
|
150
|
-
See Also:
|
151
|
-
- `mlflow.start_run`: The MLflow function to start a run directly.
|
152
|
-
- `log_run`: A context manager to log parameters and manage the MLflow
|
153
|
-
run context.
|
154
|
-
|
155
116
|
"""
|
156
|
-
if run:
|
157
|
-
run_id = run.info.run_id
|
158
|
-
|
159
117
|
with (
|
160
118
|
mlflow.start_run(
|
161
119
|
run_id=run_id,
|
@@ -12,46 +12,42 @@ import mlflow
|
|
12
12
|
import mlflow.artifacts
|
13
13
|
from hydra.core.hydra_config import HydraConfig
|
14
14
|
from mlflow.entities import Run
|
15
|
-
from omegaconf import DictConfig, OmegaConf
|
15
|
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
16
16
|
|
17
17
|
if TYPE_CHECKING:
|
18
18
|
from collections.abc import Iterable
|
19
19
|
|
20
20
|
|
21
|
-
def
|
21
|
+
def file_uri_to_path(uri: str) -> Path:
|
22
|
+
"""Convert a file URI to a local path."""
|
23
|
+
if not uri.startswith("file:"):
|
24
|
+
return Path(uri)
|
25
|
+
|
26
|
+
path = urllib.parse.urlparse(uri).path
|
27
|
+
return Path(urllib.request.url2pathname(path)) # for Windows
|
28
|
+
|
29
|
+
|
30
|
+
def get_artifact_dir(run: Run | None = None) -> Path:
|
22
31
|
"""Retrieve the artifact directory for the given run.
|
23
32
|
|
24
33
|
This function uses MLflow to get the artifact directory for the given run.
|
25
34
|
|
26
35
|
Args:
|
27
36
|
run (Run | None): The run object. Defaults to None.
|
28
|
-
uri (str | None): The URI of the artifact. Defaults to None.
|
29
37
|
|
30
38
|
Returns:
|
31
39
|
The local path to the directory where the artifacts are downloaded.
|
32
40
|
|
33
41
|
"""
|
34
|
-
if run is
|
35
|
-
raise ValueError("Cannot provide both run and uri")
|
36
|
-
|
37
|
-
if run is None and uri is None:
|
42
|
+
if run is None:
|
38
43
|
uri = mlflow.get_artifact_uri()
|
39
|
-
|
44
|
+
else:
|
40
45
|
uri = run.info.artifact_uri
|
41
46
|
|
42
47
|
if not isinstance(uri, str):
|
43
48
|
raise NotImplementedError
|
44
49
|
|
45
|
-
|
46
|
-
return file_uri_to_path(uri)
|
47
|
-
|
48
|
-
return Path(uri)
|
49
|
-
|
50
|
-
|
51
|
-
def file_uri_to_path(uri: str) -> Path:
|
52
|
-
"""Convert a file URI to a local path."""
|
53
|
-
path = urllib.parse.urlparse(uri).path
|
54
|
-
return Path(urllib.request.url2pathname(path)) # for Windows
|
50
|
+
return file_uri_to_path(uri)
|
55
51
|
|
56
52
|
|
57
53
|
def get_artifact_path(run: Run | None, path: str) -> Path:
|
@@ -123,12 +119,7 @@ def load_config(run: Run) -> DictConfig:
|
|
123
119
|
return OmegaConf.load(path) # type: ignore
|
124
120
|
|
125
121
|
|
126
|
-
def
|
127
|
-
"""Retrieve the overrides for the current run."""
|
128
|
-
return list(HydraConfig.get().overrides.task) # ListConifg -> list
|
129
|
-
|
130
|
-
|
131
|
-
def load_overrides(run: Run) -> list[str]:
|
122
|
+
def load_overrides(run: Run) -> ListConfig:
|
132
123
|
"""Load the overrides for a given run.
|
133
124
|
|
134
125
|
This function loads the overrides for the provided Run instance
|
@@ -137,15 +128,15 @@ def load_overrides(run: Run) -> list[str]:
|
|
137
128
|
`.hydra/overrides.yaml` is not found in the run's artifact directory.
|
138
129
|
|
139
130
|
Args:
|
140
|
-
run (Run): The Run instance for which to load the
|
131
|
+
run (Run): The Run instance for which to load the configuration.
|
141
132
|
|
142
133
|
Returns:
|
143
|
-
The loaded
|
144
|
-
if the
|
134
|
+
The loaded configuration as a DictConfig object. Returns an empty
|
135
|
+
DictConfig if the configuration file is not found.
|
145
136
|
|
146
137
|
"""
|
147
138
|
path = get_artifact_dir(run) / ".hydra/overrides.yaml"
|
148
|
-
return
|
139
|
+
return OmegaConf.load(path) # type: ignore
|
149
140
|
|
150
141
|
|
151
142
|
def remove_run(run: Run | Iterable[Run]) -> None:
|
hydraflow/core/main.py
ADDED
@@ -0,0 +1,164 @@
|
|
1
|
+
"""Integration of MLflow experiment tracking with Hydra configuration management.
|
2
|
+
|
3
|
+
This module provides decorators and utilities to seamlessly combine Hydra's
|
4
|
+
configuration management with MLflow's experiment tracking capabilities. It
|
5
|
+
enables automatic run deduplication, configuration storage, and experiment
|
6
|
+
management.
|
7
|
+
|
8
|
+
The main functionality is provided through the `main` decorator, which can be
|
9
|
+
used to wrap experiment entry points. This decorator handles:
|
10
|
+
|
11
|
+
- Configuration management via Hydra
|
12
|
+
- Experiment tracking via MLflow
|
13
|
+
- Run deduplication based on configurations
|
14
|
+
- Working directory management
|
15
|
+
- Automatic configuration storage
|
16
|
+
|
17
|
+
Example:
|
18
|
+
```python
|
19
|
+
from dataclasses import dataclass
|
20
|
+
from mlflow.entities import Run
|
21
|
+
|
22
|
+
@dataclass
|
23
|
+
class Config:
|
24
|
+
learning_rate: float
|
25
|
+
batch_size: int
|
26
|
+
|
27
|
+
@main(Config)
|
28
|
+
def train(run: Run, config: Config):
|
29
|
+
# Your training code here
|
30
|
+
pass
|
31
|
+
```
|
32
|
+
|
33
|
+
"""
|
34
|
+
|
35
|
+
from __future__ import annotations
|
36
|
+
|
37
|
+
from functools import wraps
|
38
|
+
from typing import TYPE_CHECKING, TypeVar
|
39
|
+
|
40
|
+
import hydra
|
41
|
+
import mlflow
|
42
|
+
from hydra.core.config_store import ConfigStore
|
43
|
+
from hydra.core.hydra_config import HydraConfig
|
44
|
+
from mlflow.entities import RunStatus
|
45
|
+
from omegaconf import OmegaConf
|
46
|
+
|
47
|
+
import hydraflow
|
48
|
+
from hydraflow.core.io import file_uri_to_path
|
49
|
+
|
50
|
+
if TYPE_CHECKING:
|
51
|
+
from collections.abc import Callable
|
52
|
+
from pathlib import Path
|
53
|
+
from typing import Any
|
54
|
+
|
55
|
+
from mlflow.entities import Run
|
56
|
+
|
57
|
+
FINISHED = RunStatus.to_string(RunStatus.FINISHED)
|
58
|
+
|
59
|
+
T = TypeVar("T")
|
60
|
+
|
61
|
+
|
62
|
+
def main(
|
63
|
+
node: T | type[T],
|
64
|
+
config_name: str = "config",
|
65
|
+
*,
|
66
|
+
chdir: bool = False,
|
67
|
+
force_new_run: bool = False,
|
68
|
+
match_overrides: bool = False,
|
69
|
+
rerun_finished: bool = False,
|
70
|
+
):
|
71
|
+
"""Decorator for configuring and running MLflow experiments with Hydra.
|
72
|
+
|
73
|
+
This decorator combines Hydra configuration management with MLflow experiment
|
74
|
+
tracking. It automatically handles run deduplication and configuration storage.
|
75
|
+
|
76
|
+
Args:
|
77
|
+
node: Configuration node class or instance defining the structure of the
|
78
|
+
configuration.
|
79
|
+
config_name: Name of the configuration. Defaults to "config".
|
80
|
+
chdir: If True, changes working directory to the artifact directory
|
81
|
+
of the run. Defaults to False.
|
82
|
+
force_new_run: If True, always creates a new MLflow run instead of
|
83
|
+
reusing existing ones. Defaults to False.
|
84
|
+
match_overrides: If True, matches runs based on Hydra CLI overrides
|
85
|
+
instead of full config. Defaults to False.
|
86
|
+
rerun_finished: If True, allows rerunning completed runs. Defaults to
|
87
|
+
False.
|
88
|
+
|
89
|
+
"""
|
90
|
+
|
91
|
+
def decorator(app: Callable[[Run, T], None]) -> Callable[[], None]:
|
92
|
+
ConfigStore.instance().store(config_name, node)
|
93
|
+
|
94
|
+
@hydra.main(config_name=config_name, version_base=None)
|
95
|
+
@wraps(app)
|
96
|
+
def inner_decorator(config: T) -> None:
|
97
|
+
hc = HydraConfig.get()
|
98
|
+
experiment = mlflow.set_experiment(hc.job.name)
|
99
|
+
|
100
|
+
if force_new_run:
|
101
|
+
run_id = None
|
102
|
+
else:
|
103
|
+
uri = experiment.artifact_location
|
104
|
+
overrides = hc.overrides.task if match_overrides else None
|
105
|
+
run_id = get_run_id(uri, config, overrides)
|
106
|
+
|
107
|
+
if run_id and not rerun_finished:
|
108
|
+
run = mlflow.get_run(run_id)
|
109
|
+
if run.info.status == FINISHED:
|
110
|
+
return
|
111
|
+
|
112
|
+
with hydraflow.start_run(config, run_id=run_id, chdir=chdir) as run:
|
113
|
+
app(run, config)
|
114
|
+
|
115
|
+
return inner_decorator
|
116
|
+
|
117
|
+
return decorator
|
118
|
+
|
119
|
+
|
120
|
+
def get_run_id(uri: str, config: Any, overrides: list[str] | None) -> str | None:
|
121
|
+
"""Try to get the run ID for the given configuration.
|
122
|
+
|
123
|
+
If the run is not found, the function will return None.
|
124
|
+
|
125
|
+
Args:
|
126
|
+
uri (str): The URI of the experiment.
|
127
|
+
config (object): The configuration object.
|
128
|
+
overrides (list[str] | None): The task overrides.
|
129
|
+
|
130
|
+
Returns:
|
131
|
+
The run ID for the given configuration or overrides. Returns None if
|
132
|
+
no run ID is found.
|
133
|
+
|
134
|
+
"""
|
135
|
+
for run_dir in file_uri_to_path(uri).iterdir():
|
136
|
+
if run_dir.is_dir() and equals(run_dir, config, overrides):
|
137
|
+
return run_dir.name
|
138
|
+
|
139
|
+
return None
|
140
|
+
|
141
|
+
|
142
|
+
def equals(run_dir: Path, config: Any, overrides: list[str] | None) -> bool:
|
143
|
+
"""Check if the run directory matches the given configuration or overrides.
|
144
|
+
|
145
|
+
Args:
|
146
|
+
run_dir (Path): The run directory.
|
147
|
+
config (object): The configuration object.
|
148
|
+
overrides (list[str] | None): The task overrides.
|
149
|
+
|
150
|
+
Returns:
|
151
|
+
True if the run directory matches the given configuration or overrides,
|
152
|
+
False otherwise.
|
153
|
+
|
154
|
+
"""
|
155
|
+
if overrides is None:
|
156
|
+
path = run_dir / "artifacts/.hydra/config.yaml"
|
157
|
+
else:
|
158
|
+
path = run_dir / "artifacts/.hydra/overrides.yaml"
|
159
|
+
config = overrides
|
160
|
+
|
161
|
+
if not path.exists():
|
162
|
+
return False
|
163
|
+
|
164
|
+
return OmegaConf.load(path) == config
|