hydraflow 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hydraflow/__init__.py +7 -13
- hydraflow/asyncio.py +199 -0
- hydraflow/config.py +3 -3
- hydraflow/context.py +72 -21
- hydraflow/mlflow.py +64 -12
- hydraflow/runs.py +428 -190
- {hydraflow-0.2.2.dist-info → hydraflow-0.2.4.dist-info}/METADATA +5 -6
- hydraflow-0.2.4.dist-info/RECORD +10 -0
- hydraflow-0.2.2.dist-info/RECORD +0 -9
- {hydraflow-0.2.2.dist-info → hydraflow-0.2.4.dist-info}/WHEEL +0 -0
- {hydraflow-0.2.2.dist-info → hydraflow-0.2.4.dist-info}/licenses/LICENSE +0 -0
hydraflow/__init__.py
CHANGED
@@ -1,28 +1,22 @@
|
|
1
|
-
from .context import
|
2
|
-
from .mlflow import set_experiment
|
1
|
+
from .context import chdir_artifact, log_run, start_run, watch
|
2
|
+
from .mlflow import get_artifact_dir, get_hydra_output_dir, set_experiment
|
3
3
|
from .runs import (
|
4
|
-
Run,
|
5
4
|
RunCollection,
|
6
|
-
|
7
|
-
get_param_dict,
|
8
|
-
get_param_names,
|
9
|
-
get_run,
|
5
|
+
list_runs,
|
10
6
|
load_config,
|
11
7
|
search_runs,
|
12
8
|
)
|
13
9
|
|
14
10
|
__all__ = [
|
15
|
-
"Info",
|
16
|
-
"Run",
|
17
11
|
"RunCollection",
|
18
12
|
"chdir_artifact",
|
19
|
-
"
|
20
|
-
"
|
21
|
-
"
|
22
|
-
"get_run",
|
13
|
+
"get_artifact_dir",
|
14
|
+
"get_hydra_output_dir",
|
15
|
+
"list_runs",
|
23
16
|
"load_config",
|
24
17
|
"log_run",
|
25
18
|
"search_runs",
|
26
19
|
"set_experiment",
|
20
|
+
"start_run",
|
27
21
|
"watch",
|
28
22
|
]
|
hydraflow/asyncio.py
ADDED
@@ -0,0 +1,199 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import logging
|
5
|
+
from asyncio.subprocess import PIPE
|
6
|
+
from pathlib import Path
|
7
|
+
from typing import TYPE_CHECKING
|
8
|
+
|
9
|
+
import watchfiles
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from asyncio.streams import StreamReader
|
13
|
+
from collections.abc import Callable
|
14
|
+
|
15
|
+
from watchfiles import Change
|
16
|
+
|
17
|
+
|
18
|
+
# Module-level logger; the coroutines below report failures through it.
logger = logging.getLogger(__name__)

# Set up logging.
# NOTE(review): `logging.basicConfig` attaches a handler to the *root* logger
# (unless it already has one) as a side effect of importing this module, which
# also affects any application importing hydraflow.asyncio. Consider leaving
# logging configuration to the application.
logging.basicConfig(level=logging.INFO)
|
21
|
+
|
22
|
+
|
23
|
+
async def execute_command(
    program: str,
    *args: str,
    stdout: Callable[[str], None] | None = None,
    stderr: Callable[[str], None] | None = None,
    stop_event: asyncio.Event,
) -> int:
    """
    Run a command asynchronously and pass its output to callback functions.

    A stream is only piped back to this process when a callback is given for
    it; otherwise it is discarded (sent to ``DEVNULL``). A pipe that nobody
    reads eventually fills up and blocks the child forever, so a stream is
    never piped without a reader.

    Args:
        program (str): The program to run.
        *args (str): Arguments for the program.
        stdout (Callable[[str], None] | None): Callback for standard output.
        stderr (Callable[[str], None] | None): Callback for standard error.
        stop_event (asyncio.Event): Event that is set when the command has
            finished, whether it succeeded or failed.

    Returns:
        int: The return code of the process, or 1 if running it failed.
    """
    process = None
    try:
        process = await asyncio.create_subprocess_exec(
            program,
            *args,
            stdout=PIPE if stdout is not None else asyncio.subprocess.DEVNULL,
            stderr=PIPE if stderr is not None else asyncio.subprocess.DEVNULL,
        )
        await asyncio.gather(
            process_stream(process.stdout, stdout),
            process_stream(process.stderr, stderr),
        )
        returncode = await process.wait()

    except Exception as e:
        logger.error(f"Error running command: {e}")
        returncode = 1
        # A failing callback must not leave the child process running
        # unattended: terminate it and reap it before reporting failure.
        if process is not None and process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass  # The process exited on its own in the meantime.
            await process.wait()

    finally:
        stop_event.set()

    return returncode
|
59
|
+
|
60
|
+
|
61
|
+
async def process_stream(
    stream: StreamReader | None,
    callback: Callable[[str], None] | None,
) -> None:
    """
    Read a stream asynchronously and pass each line to a callback function.

    Each line is decoded as UTF-8, with undecodable bytes replaced by U+FFFD
    instead of raising. Surrounding whitespace is stripped before the line is
    passed to the callback. Nothing is read if either argument is None.

    Args:
        stream (StreamReader | None): The stream to read from.
        callback (Callable[[str], None] | None): The callback function to handle
            each line.
    """
    if stream is None or callback is None:
        return

    # `async for` calls readline() repeatedly and stops when it returns b"" (EOF).
    async for line in stream:
        callback(line.decode(errors="replace").strip())
|
82
|
+
|
83
|
+
|
84
|
+
async def monitor_file_changes(
    paths: list[str | Path],
    callback: Callable[[set[tuple[Change, str]]], None],
    stop_event: asyncio.Event,
    **awatch_kwargs,
) -> None:
    """
    Watch the given paths and forward each batch of file changes to a callback.

    Watching ends once ``stop_event`` is set. Any exception raised while
    watching, including one raised by the callback, is logged instead of
    propagated.

    Args:
        paths (list[str | Path]): Paths to monitor for changes.
        callback (Callable[[set[tuple[Change, str]]], None]): Receives each
            set of ``(change, path)`` pairs reported by the watcher.
        stop_event (asyncio.Event): Event that stops the watcher when set.
        **awatch_kwargs: Extra keyword arguments forwarded to
            ``watchfiles.awatch``.
    """
    targets = [*map(str, paths)]

    try:
        async for batch in watchfiles.awatch(
            *targets, stop_event=stop_event, **awatch_kwargs
        ):
            callback(batch)
    except Exception as e:
        logger.error(f"Error watching files: {e}")
|
107
|
+
|
108
|
+
|
109
|
+
async def run_and_monitor(
    program: str,
    *args: str,
    stdout: Callable[[str], None] | None = None,
    stderr: Callable[[str], None] | None = None,
    watch: Callable[[set[tuple[Change, str]]], None] | None = None,
    paths: list[str | Path] | None = None,
    **awatch_kwargs,
) -> int:
    """
    Run a command and, optionally, watch for file changes while it runs.

    The file watcher is only started when both ``watch`` and ``paths`` are
    given. It shares a stop event with the command, so it stops when the
    command finishes.

    Args:
        program (str): The program to run.
        *args (str): Arguments for the program.
        stdout (Callable[[str], None] | None): Callback for standard output.
        stderr (Callable[[str], None] | None): Callback for standard error.
        watch (Callable[[set[tuple[Change, str]]], None] | None): Callback for
            file changes.
        paths (list[str | Path] | None): List of paths to monitor for changes.
        **awatch_kwargs: Extra keyword arguments forwarded to
            ``watchfiles.awatch``.

    Returns:
        int: The return code reported by ``execute_command``.
    """
    stop_event = asyncio.Event()

    # The command task is created first; the optional watcher comes second.
    run_task = asyncio.create_task(
        execute_command(program, *args, stop_event=stop_event, stdout=stdout, stderr=stderr)
    )
    should_watch = bool(watch and paths)
    monitor_task = (
        asyncio.create_task(monitor_file_changes(paths, watch, stop_event, **awatch_kwargs))
        if should_watch
        else None
    )

    try:
        if monitor_task is None:
            await run_task
        else:
            await asyncio.gather(run_task, monitor_task)

    except Exception as e:
        logger.error(f"Error in run_and_monitor: {e}")

    finally:
        # Signal the watcher to stop, then wait for both tasks to settle.
        stop_event.set()
        await run_task
        if monitor_task is not None:
            await monitor_task

    return run_task.result()
|
156
|
+
|
157
|
+
|
158
|
+
def run(
    program: str,
    *args: str,
    stdout: Callable[[str], None] | None = None,
    stderr: Callable[[str], None] | None = None,
    watch: Callable[[set[tuple[Change, str]]], None] | None = None,
    paths: list[str | Path] | None = None,
    **awatch_kwargs,
) -> int:
    """
    Run a command synchronously, optionally watching for file changes.

    This is a blocking wrapper around the coroutine `run_and_monitor`, driven
    by `asyncio.run`. When a ``watch`` callback is given without any
    ``paths``, the current working directory is watched.

    Args:
        program (str): The program to run.
        *args (str): Arguments for the program.
        stdout (Callable[[str], None] | None): Callback for each standard
            output line.
        stderr (Callable[[str], None] | None): Callback for each standard
            error line.
        watch (Callable[[set[tuple[Change, str]]], None] | None): Callback for
            file changes.
        paths (list[str | Path] | None): Paths to monitor for file changes.
        **awatch_kwargs: Extra keyword arguments forwarded to
            `watchfiles.awatch`.

    Returns:
        int: The return code of the process.
    """
    # Equivalent to: if watch and not paths, fall back to the cwd.
    watch_paths = paths if (paths or not watch) else [Path.cwd()]

    coroutine = run_and_monitor(
        program,
        *args,
        stdout=stdout,
        stderr=stderr,
        watch=watch,
        paths=watch_paths,
        **awatch_kwargs,
    )
    return asyncio.run(coroutine)
|
hydraflow/config.py
CHANGED
@@ -22,9 +22,9 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
|
22
22
|
representing the parameters. The keys are prefixed with the provided prefix.
|
23
23
|
|
24
24
|
Args:
|
25
|
-
config: The configuration object to iterate over. This can be a
|
26
|
-
list, DictConfig, or ListConfig.
|
27
|
-
prefix: The prefix to prepend to the parameter keys.
|
25
|
+
config (object): The configuration object to iterate over. This can be a
|
26
|
+
dictionary, list, DictConfig, or ListConfig.
|
27
|
+
prefix (str): The prefix to prepend to the parameter keys.
|
28
28
|
Defaults to an empty string.
|
29
29
|
|
30
30
|
Yields:
|
hydraflow/context.py
CHANGED
@@ -9,7 +9,6 @@ import logging
|
|
9
9
|
import os
|
10
10
|
import time
|
11
11
|
from contextlib import contextmanager
|
12
|
-
from dataclasses import dataclass
|
13
12
|
from pathlib import Path
|
14
13
|
from typing import TYPE_CHECKING
|
15
14
|
|
@@ -28,18 +27,12 @@ if TYPE_CHECKING:
|
|
28
27
|
log = logging.getLogger(__name__)
|
29
28
|
|
30
29
|
|
31
|
-
@dataclass
|
32
|
-
class Info:
|
33
|
-
output_dir: Path
|
34
|
-
artifact_dir: Path
|
35
|
-
|
36
|
-
|
37
30
|
@contextmanager
|
38
31
|
def log_run(
|
39
32
|
config: object,
|
40
33
|
*,
|
41
34
|
synchronous: bool | None = None,
|
42
|
-
) -> Iterator[
|
35
|
+
) -> Iterator[None]:
|
43
36
|
"""
|
44
37
|
Log the parameters from the given configuration object and manage the MLflow
|
45
38
|
run context.
|
@@ -49,16 +42,15 @@ def log_run(
|
|
49
42
|
are logged and the run is properly closed.
|
50
43
|
|
51
44
|
Args:
|
52
|
-
config: The configuration object to log the parameters from.
|
53
|
-
synchronous: Whether to log the parameters synchronously.
|
45
|
+
config (object): The configuration object to log the parameters from.
|
46
|
+
synchronous (bool | None): Whether to log the parameters synchronously.
|
54
47
|
Defaults to None.
|
55
48
|
|
56
49
|
Yields:
|
57
|
-
|
58
|
-
paths.
|
50
|
+
None
|
59
51
|
|
60
52
|
Example:
|
61
|
-
with log_run(config)
|
53
|
+
with log_run(config):
|
62
54
|
# Perform operations within the MLflow run context
|
63
55
|
pass
|
64
56
|
"""
|
@@ -66,7 +58,6 @@ def log_run(
|
|
66
58
|
|
67
59
|
hc = HydraConfig.get()
|
68
60
|
output_dir = Path(hc.runtime.output_dir)
|
69
|
-
info = Info(output_dir, get_artifact_dir())
|
70
61
|
|
71
62
|
# Save '.hydra' config directory first.
|
72
63
|
output_subdir = output_dir / (hc.output_subdir or "")
|
@@ -78,7 +69,7 @@ def log_run(
|
|
78
69
|
|
79
70
|
try:
|
80
71
|
with watch(log_artifact, output_dir):
|
81
|
-
yield
|
72
|
+
yield
|
82
73
|
|
83
74
|
except Exception as e:
|
84
75
|
log.error(f"Error during log_run: {e}")
|
@@ -89,6 +80,64 @@ def log_run(
|
|
89
80
|
mlflow.log_artifacts(output_dir.as_posix())
|
90
81
|
|
91
82
|
|
83
|
+
@contextmanager
def start_run(
    config: object,
    *,
    run_id: str | None = None,
    experiment_id: str | None = None,
    run_name: str | None = None,
    nested: bool = False,
    parent_run_id: str | None = None,
    tags: dict[str, str] | None = None,
    description: str | None = None,
    log_system_metrics: bool | None = None,
    synchronous: bool | None = None,
) -> Iterator[Run]:
    """
    Start an MLflow run and log parameters from a configuration object.

    All run-related options are forwarded unchanged to `mlflow.start_run`.
    Inside that run, `log_run` is entered with ``config``, so parameter
    logging and run cleanup are handled the same way as when `log_run` is
    used directly.

    Args:
        config (object): The configuration object to log parameters from.
        run_id (str | None): The existing run ID. Defaults to None.
        experiment_id (str | None): The experiment ID. Defaults to None.
        run_name (str | None): The name of the run. Defaults to None.
        nested (bool): Whether to allow nested runs. Defaults to False.
        parent_run_id (str | None): The parent run ID. Defaults to None.
        tags (dict[str, str] | None): Tags to associate with the run.
            Defaults to None.
        description (str | None): A description of the run. Defaults to None.
        log_system_metrics (bool | None): Whether to log system metrics.
            Defaults to None.
        synchronous (bool | None): Whether to log parameters synchronously.
            Defaults to None.

    Yields:
        Run: The MLflow Run object returned by `mlflow.start_run`.

    Example:
        with start_run(config) as run:
            # Perform operations within the MLflow run context
            pass

    See Also:
        `mlflow.start_run`: The MLflow function to start a run directly.
        `log_run`: A context manager to log parameters and manage the MLflow
            run context.
    """
    run_options = dict(
        run_id=run_id,
        experiment_id=experiment_id,
        run_name=run_name,
        nested=nested,
        parent_run_id=parent_run_id,
        tags=tags,
        description=description,
        log_system_metrics=log_system_metrics,
    )

    # The MLflow run is entered first; log_run is entered inside it.
    with mlflow.start_run(**run_options) as run, log_run(config, synchronous=synchronous):
        yield run
|
139
|
+
|
140
|
+
|
92
141
|
@contextmanager
|
93
142
|
def watch(
|
94
143
|
func: Callable[[Path], None],
|
@@ -105,12 +154,12 @@ def watch(
|
|
105
154
|
period or until the context is exited.
|
106
155
|
|
107
156
|
Args:
|
108
|
-
func: The function to call when a change is
|
157
|
+
func (Callable[[Path], None]): The function to call when a change is
|
109
158
|
detected. It should accept a single argument of type `Path`,
|
110
159
|
which is the path of the modified file.
|
111
|
-
dir: The directory to watch. If not specified,
|
160
|
+
dir (Path | str): The directory to watch. If not specified,
|
112
161
|
the current MLflow artifact URI is used. Defaults to "".
|
113
|
-
timeout: The timeout period in seconds for the watcher
|
162
|
+
timeout (int): The timeout period in seconds for the watcher
|
114
163
|
to run after the context is exited. Defaults to 60.
|
115
164
|
|
116
165
|
Yields:
|
@@ -122,6 +171,8 @@ def watch(
|
|
122
171
|
pass
|
123
172
|
"""
|
124
173
|
dir = dir or get_artifact_dir()
|
174
|
+
if isinstance(dir, Path):
|
175
|
+
dir = dir.as_posix()
|
125
176
|
|
126
177
|
handler = Handler(func)
|
127
178
|
observer = Observer()
|
@@ -152,7 +203,7 @@ class Handler(FileSystemEventHandler):
|
|
152
203
|
self.func = func
|
153
204
|
|
154
205
|
def on_modified(self, event: FileModifiedEvent) -> None:
|
155
|
-
file = Path(event.src_path)
|
206
|
+
file = Path(str(event.src_path))
|
156
207
|
if file.is_file():
|
157
208
|
self.func(file)
|
158
209
|
|
@@ -171,8 +222,8 @@ def chdir_artifact(
|
|
171
222
|
to the original directory after the context is exited.
|
172
223
|
|
173
224
|
Args:
|
174
|
-
run: The run to get the artifact directory from.
|
175
|
-
artifact_path: The artifact path.
|
225
|
+
run (Run): The run to get the artifact directory from.
|
226
|
+
artifact_path (str | None): The artifact path.
|
176
227
|
"""
|
177
228
|
curdir = Path.cwd()
|
178
229
|
path = mlflow.artifacts.download_artifacts(
|
hydraflow/mlflow.py
CHANGED
@@ -6,14 +6,24 @@ configuration objects and set up experiments using MLflow.
|
|
6
6
|
from __future__ import annotations
|
7
7
|
|
8
8
|
from pathlib import Path
|
9
|
+
from typing import TYPE_CHECKING
|
9
10
|
|
10
11
|
import mlflow
|
11
12
|
from hydra.core.hydra_config import HydraConfig
|
13
|
+
from mlflow.tracking import artifact_utils
|
14
|
+
from omegaconf import OmegaConf
|
12
15
|
|
13
16
|
from hydraflow.config import iter_params
|
14
17
|
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
from mlflow.entities.experiment import Experiment
|
15
20
|
|
16
|
-
|
21
|
+
|
22
|
+
def set_experiment(
|
23
|
+
prefix: str = "",
|
24
|
+
suffix: str = "",
|
25
|
+
uri: str | Path | None = None,
|
26
|
+
) -> Experiment:
|
17
27
|
"""
|
18
28
|
Set the experiment name and tracking URI optionally.
|
19
29
|
|
@@ -22,16 +32,20 @@ def set_experiment(prefix: str = "", suffix: str = "", uri: str | None = None) -
|
|
22
32
|
also set the tracking URI.
|
23
33
|
|
24
34
|
Args:
|
25
|
-
prefix: The prefix to prepend to the experiment name.
|
26
|
-
suffix: The suffix to append to the experiment name.
|
27
|
-
uri: The tracking URI to use.
|
35
|
+
prefix (str): The prefix to prepend to the experiment name.
|
36
|
+
suffix (str): The suffix to append to the experiment name.
|
37
|
+
uri (str | Path | None): The tracking URI to use. Defaults to None.
|
38
|
+
|
39
|
+
Returns:
|
40
|
+
Experiment: An instance of `mlflow.entities.Experiment` representing
|
41
|
+
the new active experiment.
|
28
42
|
"""
|
29
|
-
if uri:
|
43
|
+
if uri is not None:
|
30
44
|
mlflow.set_tracking_uri(uri)
|
31
45
|
|
32
46
|
hc = HydraConfig.get()
|
33
47
|
name = f"{prefix}{hc.job.name}{suffix}"
|
34
|
-
mlflow.set_experiment(name)
|
48
|
+
return mlflow.set_experiment(name)
|
35
49
|
|
36
50
|
|
37
51
|
def log_params(config: object, *, synchronous: bool | None = None) -> None:
|
@@ -43,15 +57,19 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
|
|
43
57
|
`mlflow.log_param` method.
|
44
58
|
|
45
59
|
Args:
|
46
|
-
config: The configuration object to log the parameters from.
|
47
|
-
synchronous: Whether to log the parameters synchronously.
|
60
|
+
config (object): The configuration object to log the parameters from.
|
61
|
+
synchronous (bool | None): Whether to log the parameters synchronously.
|
48
62
|
Defaults to None.
|
49
63
|
"""
|
50
64
|
for key, value in iter_params(config):
|
51
65
|
mlflow.log_param(key, value, synchronous=synchronous)
|
52
66
|
|
53
67
|
|
54
|
-
def get_artifact_dir(
|
68
|
+
def get_artifact_dir(
|
69
|
+
artifact_path: str | None = None,
|
70
|
+
*,
|
71
|
+
run_id: str | None = None,
|
72
|
+
) -> Path:
|
55
73
|
"""
|
56
74
|
Get the artifact directory for the given artifact path.
|
57
75
|
|
@@ -60,13 +78,47 @@ def get_artifact_dir(artifact_path: str | None = None) -> Path:
|
|
60
78
|
the path to that directory.
|
61
79
|
|
62
80
|
Args:
|
63
|
-
artifact_path: The artifact path for which to get the
|
64
|
-
Defaults to None.
|
81
|
+
artifact_path (str | None): The artifact path for which to get the
|
82
|
+
directory. Defaults to None.
|
83
|
+
run_id (str | None): The run ID for which to get the artifact directory.
|
65
84
|
|
66
85
|
Returns:
|
67
86
|
The local path to the directory where the artifacts are downloaded.
|
68
87
|
"""
|
69
|
-
|
88
|
+
if run_id is None:
|
89
|
+
uri = mlflow.get_artifact_uri(artifact_path)
|
90
|
+
else:
|
91
|
+
uri = artifact_utils.get_artifact_uri(run_id, artifact_path)
|
92
|
+
|
70
93
|
dir = mlflow.artifacts.download_artifacts(artifact_uri=uri)
|
71
94
|
|
72
95
|
return Path(dir)
|
96
|
+
|
97
|
+
|
98
|
+
def get_hydra_output_dir(*, run_id: str | None = None) -> Path:
    """
    Get the Hydra output directory.

    Without a run ID, the output directory of the currently running Hydra job
    is returned. With a run ID, the directory is read from the
    ``.hydra/hydra.yaml`` file in that run's artifact directory.

    Args:
        run_id (str | None): The run ID to look up. Defaults to None, meaning
            the current Hydra job.

    Returns:
        Path: The Hydra output directory.

    Raises:
        FileNotFoundError: If the run's artifacts contain no
            ``.hydra/hydra.yaml`` file.
    """
    if run_id is None:
        hc = HydraConfig.get()
        return Path(hc.runtime.output_dir)

    # NOTE(review): assumes the Hydra config was saved under the default
    # ".hydra" output subdirectory; a customized `hydra.output_subdir` would
    # not be found here — confirm against how runs are logged.
    path = get_artifact_dir(run_id=run_id) / ".hydra/hydra.yaml"

    if not path.exists():
        msg = f"Hydra config file not found for run {run_id!r}: {path}"
        raise FileNotFoundError(msg)

    hc = OmegaConf.load(path)
    return Path(hc.hydra.runtime.output_dir)
|
110
|
+
|
111
|
+
|
112
|
+
# def log_hydra_output_dir(run: Run_ | Series | str) -> None:
|
113
|
+
# """
|
114
|
+
# Log the Hydra output directory.
|
115
|
+
|
116
|
+
# Args:
|
117
|
+
# run: The run object.
|
118
|
+
|
119
|
+
# Returns:
|
120
|
+
# None
|
121
|
+
# """
|
122
|
+
# output_dir = get_hydra_output_dir(run)
|
123
|
+
# run_id = run if isinstance(run, str) else run.info.run_id
|
124
|
+
# mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
|