hydraflow 0.2.2__tar.gz → 0.2.4__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (30)
  1. {hydraflow-0.2.2 → hydraflow-0.2.4}/.devcontainer/devcontainer.json +1 -2
  2. {hydraflow-0.2.2 → hydraflow-0.2.4}/.gitignore +3 -2
  3. {hydraflow-0.2.2 → hydraflow-0.2.4}/PKG-INFO +5 -6
  4. {hydraflow-0.2.2 → hydraflow-0.2.4}/README.md +2 -5
  5. {hydraflow-0.2.2 → hydraflow-0.2.4}/pyproject.toml +16 -4
  6. hydraflow-0.2.4/src/hydraflow/__init__.py +22 -0
  7. hydraflow-0.2.4/src/hydraflow/asyncio.py +199 -0
  8. {hydraflow-0.2.2 → hydraflow-0.2.4}/src/hydraflow/config.py +3 -3
  9. {hydraflow-0.2.2 → hydraflow-0.2.4}/src/hydraflow/context.py +72 -21
  10. hydraflow-0.2.4/src/hydraflow/mlflow.py +124 -0
  11. hydraflow-0.2.4/src/hydraflow/runs.py +844 -0
  12. {hydraflow-0.2.2 → hydraflow-0.2.4}/tests/scripts/log_run.py +7 -5
  13. hydraflow-0.2.4/tests/test_asyncio.py +159 -0
  14. hydraflow-0.2.4/tests/test_context.py +80 -0
  15. {hydraflow-0.2.2 → hydraflow-0.2.4}/tests/test_log_run.py +1 -1
  16. {hydraflow-0.2.2 → hydraflow-0.2.4}/tests/test_runs.py +129 -5
  17. hydraflow-0.2.2/src/hydraflow/__init__.py +0 -28
  18. hydraflow-0.2.2/src/hydraflow/mlflow.py +0 -72
  19. hydraflow-0.2.2/src/hydraflow/runs.py +0 -606
  20. hydraflow-0.2.2/tests/test_context.py +0 -36
  21. {hydraflow-0.2.2 → hydraflow-0.2.4}/.devcontainer/postCreate.sh +0 -0
  22. {hydraflow-0.2.2 → hydraflow-0.2.4}/.devcontainer/starship.toml +0 -0
  23. {hydraflow-0.2.2 → hydraflow-0.2.4}/.gitattributes +0 -0
  24. {hydraflow-0.2.2 → hydraflow-0.2.4}/LICENSE +0 -0
  25. {hydraflow-0.2.2 → hydraflow-0.2.4}/tests/scripts/__init__.py +0 -0
  26. {hydraflow-0.2.2 → hydraflow-0.2.4}/tests/scripts/watch.py +0 -0
  27. {hydraflow-0.2.2 → hydraflow-0.2.4}/tests/test_config.py +0 -0
  28. {hydraflow-0.2.2 → hydraflow-0.2.4}/tests/test_mlflow.py +0 -0
  29. {hydraflow-0.2.2 → hydraflow-0.2.4}/tests/test_version.py +0 -0
  30. {hydraflow-0.2.2 → hydraflow-0.2.4}/tests/test_watch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "hydraflow",
3
- "image": "mcr.microsoft.com/vscode/devcontainers/base:ubuntu-22.04",
3
+ "image": "mcr.microsoft.com/vscode/devcontainers/python:3.12",
4
4
  "features": {
5
5
  "ghcr.io/devcontainers-contrib/features/starship:1": {},
6
6
  "ghcr.io/va-h/devcontainers-features/uv:1": {}
@@ -9,7 +9,6 @@
9
9
  "vscode": {
10
10
  "extensions": [
11
11
  "charliermarsh.ruff",
12
- "henriiik.vscode-sort",
13
12
  "ms-python.python",
14
13
  "ms-python.vscode-pylance"
15
14
  ]
@@ -1,5 +1,6 @@
1
1
  .coverage
2
+ .env
2
3
  .venv/
3
4
  __pycache__/
4
- lcov.info
5
- dist/
5
+ dist/
6
+ lcov.info
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.2.2
3
+ Version: 0.2.4
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -20,7 +20,9 @@ Requires-Dist: hydra-core>1.3
20
20
  Requires-Dist: mlflow>2.15
21
21
  Requires-Dist: setuptools
22
22
  Requires-Dist: watchdog
23
+ Requires-Dist: watchfiles
23
24
  Provides-Extra: dev
25
+ Requires-Dist: pytest-asyncio; extra == 'dev'
24
26
  Requires-Dist: pytest-clarity; extra == 'dev'
25
27
  Requires-Dist: pytest-cov; extra == 'dev'
26
28
  Requires-Dist: pytest-randomly; extra == 'dev'
@@ -97,13 +99,10 @@ def my_app(cfg: MySQLConfig) -> None:
97
99
  # Set experiment by Hydra job name.
98
100
  hydraflow.set_experiment()
99
101
 
100
- # Automatically log params using Hydra config.
101
- with mlflow.start_run(), hydraflow.log_run(cfg) as info:
102
+ # Automatically log Hydra config as params.
103
+ with hydraflow.start_run():
102
104
  # Your app code below.
103
105
 
104
- # `info.output_dir` is the Hydra output directory.
105
- # `info.artifact_dir` is the MLflow artifact directory.
106
-
107
106
  with hydraflow.watch(callback):
108
107
  # Watch files in the MLflow artifact directory.
109
108
  # You can update metrics or log other artifacts
@@ -68,13 +68,10 @@ def my_app(cfg: MySQLConfig) -> None:
68
68
  # Set experiment by Hydra job name.
69
69
  hydraflow.set_experiment()
70
70
 
71
- # Automatically log params using Hydra config.
72
- with mlflow.start_run(), hydraflow.log_run(cfg) as info:
71
+ # Automatically log Hydra config as params.
72
+ with hydraflow.start_run():
73
73
  # Your app code below.
74
74
 
75
- # `info.output_dir` is the Hydra output directory.
76
- # `info.artifact_dir` is the MLflow artifact directory.
77
-
78
75
  with hydraflow.watch(callback):
79
76
  # Watch files in the MLflow artifact directory.
80
77
  # You can update metrics or log other artifacts
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.2.2"
7
+ version = "0.2.4"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -19,10 +19,22 @@ classifiers = [
19
19
  "Topic :: Software Development :: Documentation",
20
20
  ]
21
21
  requires-python = ">=3.10"
22
- dependencies = ["hydra-core>1.3", "mlflow>2.15", "setuptools", "watchdog"]
22
+ dependencies = [
23
+ "hydra-core>1.3",
24
+ "mlflow>2.15",
25
+ "setuptools",
26
+ "watchdog",
27
+ "watchfiles",
28
+ ]
23
29
 
24
30
  [project.optional-dependencies]
25
- dev = ["pytest-clarity", "pytest-cov", "pytest-randomly", "pytest-xdist"]
31
+ dev = [
32
+ "pytest-asyncio",
33
+ "pytest-clarity",
34
+ "pytest-cov",
35
+ "pytest-randomly",
36
+ "pytest-xdist",
37
+ ]
26
38
 
27
39
  [project.urls]
28
40
  Documentation = "https://github.com/daizutabi/hydraflow"
@@ -41,9 +53,9 @@ addopts = [
41
53
  "--cov=hydraflow",
42
54
  "--cov-report=lcov:lcov.info",
43
55
  ]
44
-
45
56
  doctest_optionflags = ["NORMALIZE_WHITESPACE", "IGNORE_EXCEPTION_DETAIL"]
46
57
  filterwarnings = ['ignore:pkg_resources is deprecated:DeprecationWarning']
58
+ asyncio_default_fixture_loop_scope = "function"
47
59
 
48
60
  [tool.coverage.report]
49
61
  exclude_lines = ["no cov", "raise NotImplementedError", "if TYPE_CHECKING:"]
@@ -0,0 +1,22 @@
1
+ from .context import chdir_artifact, log_run, start_run, watch
2
+ from .mlflow import get_artifact_dir, get_hydra_output_dir, set_experiment
3
+ from .runs import (
4
+ RunCollection,
5
+ list_runs,
6
+ load_config,
7
+ search_runs,
8
+ )
9
+
10
+ __all__ = [
11
+ "RunCollection",
12
+ "chdir_artifact",
13
+ "get_artifact_dir",
14
+ "get_hydra_output_dir",
15
+ "list_runs",
16
+ "load_config",
17
+ "log_run",
18
+ "search_runs",
19
+ "set_experiment",
20
+ "start_run",
21
+ "watch",
22
+ ]
@@ -0,0 +1,199 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ from asyncio.subprocess import PIPE
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING
8
+
9
+ import watchfiles
10
+
11
+ if TYPE_CHECKING:
12
+ from asyncio.streams import StreamReader
13
+ from collections.abc import Callable
14
+
15
+ from watchfiles import Change
16
+
17
+
18
+ # Set up logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ async def execute_command(
24
+ program: str,
25
+ *args: str,
26
+ stdout: Callable[[str], None] | None = None,
27
+ stderr: Callable[[str], None] | None = None,
28
+ stop_event: asyncio.Event,
29
+ ) -> int:
30
+ """
31
+ Runs a command asynchronously and pass the output to callback functions.
32
+
33
+ Args:
34
+ program (str): The program to run.
35
+ *args (str): Arguments for the program.
36
+ stdout (Callable[[str], None] | None): Callback for standard output.
37
+ stderr (Callable[[str], None] | None): Callback for standard error.
38
+ stop_event (asyncio.Event): Event to signal when the process is done.
39
+
40
+ Returns:
41
+ int: The return code of the process.
42
+ """
43
+ try:
44
+ process = await asyncio.create_subprocess_exec(program, *args, stdout=PIPE, stderr=PIPE)
45
+ await asyncio.gather(
46
+ process_stream(process.stdout, stdout),
47
+ process_stream(process.stderr, stderr),
48
+ )
49
+ returncode = await process.wait()
50
+
51
+ except Exception as e:
52
+ logger.error(f"Error running command: {e}")
53
+ returncode = 1
54
+
55
+ finally:
56
+ stop_event.set()
57
+
58
+ return returncode
59
+
60
+
61
+ async def process_stream(
62
+ stream: StreamReader | None,
63
+ callback: Callable[[str], None] | None,
64
+ ) -> None:
65
+ """
66
+ Reads a stream asynchronously and pass each line to a callback function.
67
+
68
+ Args:
69
+ stream (StreamReader | None): The stream to read from.
70
+ callback (Callable[[str], None] | None): The callback function to handle
71
+ each line.
72
+ """
73
+ if stream is None or callback is None:
74
+ return
75
+
76
+ while True:
77
+ line = await stream.readline()
78
+ if line:
79
+ callback(line.decode().strip())
80
+ else:
81
+ break
82
+
83
+
84
+ async def monitor_file_changes(
85
+ paths: list[str | Path],
86
+ callback: Callable[[set[tuple[Change, str]]], None],
87
+ stop_event: asyncio.Event,
88
+ **awatch_kwargs,
89
+ ) -> None:
90
+ """
91
+ Watches for file changes in specified paths and pass the changes to a
92
+ callback function.
93
+
94
+ Args:
95
+ paths (list[str | Path]): List of paths to monitor for changes.
96
+ callback (Callable[[set[tuple[Change, str]]], None]): The callback
97
+ function to handle file changes.
98
+ stop_event (asyncio.Event): Event to signal when to stop watching.
99
+ **awatch_kwargs: Additional keyword arguments to pass to watchfiles.awatch.
100
+ """
101
+ str_paths = [str(path) for path in paths]
102
+ try:
103
+ async for changes in watchfiles.awatch(*str_paths, stop_event=stop_event, **awatch_kwargs):
104
+ callback(changes)
105
+ except Exception as e:
106
+ logger.error(f"Error watching files: {e}")
107
+
108
+
109
+ async def run_and_monitor(
110
+ program: str,
111
+ *args: str,
112
+ stdout: Callable[[str], None] | None = None,
113
+ stderr: Callable[[str], None] | None = None,
114
+ watch: Callable[[set[tuple[Change, str]]], None] | None = None,
115
+ paths: list[str | Path] | None = None,
116
+ **awatch_kwargs,
117
+ ) -> int:
118
+ """
119
+ Runs a command and optionally watch for file changes concurrently.
120
+
121
+ Args:
122
+ program (str): The program to run.
123
+ *args (str): Arguments for the program.
124
+ stdout (Callable[[str], None] | None): Callback for standard output.
125
+ stderr (Callable[[str], None] | None): Callback for standard error.
126
+ watch (Callable[[set[tuple[Change, str]]], None] | None): Callback for
127
+ file changes.
128
+ paths (list[str | Path] | None): List of paths to monitor for changes.
129
+ """
130
+ stop_event = asyncio.Event()
131
+ run_task = asyncio.create_task(
132
+ execute_command(program, *args, stop_event=stop_event, stdout=stdout, stderr=stderr)
133
+ )
134
+ if watch and paths:
135
+ monitor_task = asyncio.create_task(
136
+ monitor_file_changes(paths, watch, stop_event, **awatch_kwargs)
137
+ )
138
+ else:
139
+ monitor_task = None
140
+
141
+ try:
142
+ if monitor_task:
143
+ await asyncio.gather(run_task, monitor_task)
144
+ else:
145
+ await run_task
146
+
147
+ except Exception as e:
148
+ logger.error(f"Error in run_and_monitor: {e}")
149
+ finally:
150
+ stop_event.set()
151
+ await run_task
152
+ if monitor_task:
153
+ await monitor_task
154
+
155
+ return run_task.result()
156
+
157
+
158
+ def run(
159
+ program: str,
160
+ *args: str,
161
+ stdout: Callable[[str], None] | None = None,
162
+ stderr: Callable[[str], None] | None = None,
163
+ watch: Callable[[set[tuple[Change, str]]], None] | None = None,
164
+ paths: list[str | Path] | None = None,
165
+ **awatch_kwargs,
166
+ ) -> int:
167
+ """
168
+ Run a command synchronously and optionally watch for file changes.
169
+
170
+ This function is a synchronous wrapper around the asynchronous `run_and_monitor` function.
171
+ It runs a specified command and optionally monitors specified paths for file changes,
172
+ invoking the provided callbacks for standard output, standard error, and file changes.
173
+
174
+ Args:
175
+ program (str): The program to run.
176
+ *args (str): Arguments for the program.
177
+ stdout (Callable[[str], None] | None): Callback for handling standard output lines.
178
+ stderr (Callable[[str], None] | None): Callback for handling standard error lines.
179
+ watch (Callable[[set[tuple[Change, str]]], None] | None): Callback for handling file changes.
180
+ paths (list[str | Path] | None): List of paths to monitor for file changes.
181
+ **awatch_kwargs: Additional keyword arguments to pass to `watchfiles.awatch`.
182
+
183
+ Returns:
184
+ int: The return code of the process.
185
+ """
186
+ if watch and not paths:
187
+ paths = [Path.cwd()]
188
+
189
+ return asyncio.run(
190
+ run_and_monitor(
191
+ program,
192
+ *args,
193
+ stdout=stdout,
194
+ stderr=stderr,
195
+ watch=watch,
196
+ paths=paths,
197
+ **awatch_kwargs,
198
+ )
199
+ )
@@ -22,9 +22,9 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
22
22
  representing the parameters. The keys are prefixed with the provided prefix.
23
23
 
24
24
  Args:
25
- config: The configuration object to iterate over. This can be a dictionary,
26
- list, DictConfig, or ListConfig.
27
- prefix: The prefix to prepend to the parameter keys.
25
+ config (object): The configuration object to iterate over. This can be a
26
+ dictionary, list, DictConfig, or ListConfig.
27
+ prefix (str): The prefix to prepend to the parameter keys.
28
28
  Defaults to an empty string.
29
29
 
30
30
  Yields:
@@ -9,7 +9,6 @@ import logging
9
9
  import os
10
10
  import time
11
11
  from contextlib import contextmanager
12
- from dataclasses import dataclass
13
12
  from pathlib import Path
14
13
  from typing import TYPE_CHECKING
15
14
 
@@ -28,18 +27,12 @@ if TYPE_CHECKING:
28
27
  log = logging.getLogger(__name__)
29
28
 
30
29
 
31
- @dataclass
32
- class Info:
33
- output_dir: Path
34
- artifact_dir: Path
35
-
36
-
37
30
  @contextmanager
38
31
  def log_run(
39
32
  config: object,
40
33
  *,
41
34
  synchronous: bool | None = None,
42
- ) -> Iterator[Info]:
35
+ ) -> Iterator[None]:
43
36
  """
44
37
  Log the parameters from the given configuration object and manage the MLflow
45
38
  run context.
@@ -49,16 +42,15 @@ def log_run(
49
42
  are logged and the run is properly closed.
50
43
 
51
44
  Args:
52
- config: The configuration object to log the parameters from.
53
- synchronous: Whether to log the parameters synchronously.
45
+ config (object): The configuration object to log the parameters from.
46
+ synchronous (bool | None): Whether to log the parameters synchronously.
54
47
  Defaults to None.
55
48
 
56
49
  Yields:
57
- Info: An `Info` object containing the output directory and artifact directory
58
- paths.
50
+ None
59
51
 
60
52
  Example:
61
- with log_run(config) as info:
53
+ with log_run(config):
62
54
  # Perform operations within the MLflow run context
63
55
  pass
64
56
  """
@@ -66,7 +58,6 @@ def log_run(
66
58
 
67
59
  hc = HydraConfig.get()
68
60
  output_dir = Path(hc.runtime.output_dir)
69
- info = Info(output_dir, get_artifact_dir())
70
61
 
71
62
  # Save '.hydra' config directory first.
72
63
  output_subdir = output_dir / (hc.output_subdir or "")
@@ -78,7 +69,7 @@ def log_run(
78
69
 
79
70
  try:
80
71
  with watch(log_artifact, output_dir):
81
- yield info
72
+ yield
82
73
 
83
74
  except Exception as e:
84
75
  log.error(f"Error during log_run: {e}")
@@ -89,6 +80,64 @@ def log_run(
89
80
  mlflow.log_artifacts(output_dir.as_posix())
90
81
 
91
82
 
83
+ @contextmanager
84
+ def start_run(
85
+ config: object,
86
+ *,
87
+ run_id: str | None = None,
88
+ experiment_id: str | None = None,
89
+ run_name: str | None = None,
90
+ nested: bool = False,
91
+ parent_run_id: str | None = None,
92
+ tags: dict[str, str] | None = None,
93
+ description: str | None = None,
94
+ log_system_metrics: bool | None = None,
95
+ synchronous: bool | None = None,
96
+ ) -> Iterator[Run]:
97
+ """
98
+ Start an MLflow run and log parameters using the provided configuration object.
99
+
100
+ This context manager starts an MLflow run and logs parameters using the specified
101
+ configuration object. It ensures that the run is properly closed after completion.
102
+
103
+ Args:
104
+ config (object): The configuration object to log parameters from.
105
+ run_id (str | None): The existing run ID. Defaults to None.
106
+ experiment_id (str | None): The experiment ID. Defaults to None.
107
+ run_name (str | None): The name of the run. Defaults to None.
108
+ nested (bool): Whether to allow nested runs. Defaults to False.
109
+ parent_run_id (str | None): The parent run ID. Defaults to None.
110
+ tags (dict[str, str] | None): Tags to associate with the run. Defaults to None.
111
+ description (str | None): A description of the run. Defaults to None.
112
+ log_system_metrics (bool | None): Whether to log system metrics. Defaults to None.
113
+ synchronous (bool | None): Whether to log parameters synchronously. Defaults to None.
114
+
115
+ Yields:
116
+ Run: An MLflow Run object representing the started run.
117
+
118
+ Example:
119
+ with start_run(config) as run:
120
+ # Perform operations within the MLflow run context
121
+ pass
122
+
123
+ See Also:
124
+ `mlflow.start_run`: The MLflow function to start a run directly.
125
+ `log_run`: A context manager to log parameters and manage the MLflow run context.
126
+ """
127
+ with mlflow.start_run(
128
+ run_id=run_id,
129
+ experiment_id=experiment_id,
130
+ run_name=run_name,
131
+ nested=nested,
132
+ parent_run_id=parent_run_id,
133
+ tags=tags,
134
+ description=description,
135
+ log_system_metrics=log_system_metrics,
136
+ ) as run:
137
+ with log_run(config, synchronous=synchronous):
138
+ yield run
139
+
140
+
92
141
  @contextmanager
93
142
  def watch(
94
143
  func: Callable[[Path], None],
@@ -105,12 +154,12 @@ def watch(
105
154
  period or until the context is exited.
106
155
 
107
156
  Args:
108
- func: The function to call when a change is
157
+ func (Callable[[Path], None]): The function to call when a change is
109
158
  detected. It should accept a single argument of type `Path`,
110
159
  which is the path of the modified file.
111
- dir: The directory to watch. If not specified,
160
+ dir (Path | str): The directory to watch. If not specified,
112
161
  the current MLflow artifact URI is used. Defaults to "".
113
- timeout: The timeout period in seconds for the watcher
162
+ timeout (int): The timeout period in seconds for the watcher
114
163
  to run after the context is exited. Defaults to 60.
115
164
 
116
165
  Yields:
@@ -122,6 +171,8 @@ def watch(
122
171
  pass
123
172
  """
124
173
  dir = dir or get_artifact_dir()
174
+ if isinstance(dir, Path):
175
+ dir = dir.as_posix()
125
176
 
126
177
  handler = Handler(func)
127
178
  observer = Observer()
@@ -152,7 +203,7 @@ class Handler(FileSystemEventHandler):
152
203
  self.func = func
153
204
 
154
205
  def on_modified(self, event: FileModifiedEvent) -> None:
155
- file = Path(event.src_path)
206
+ file = Path(str(event.src_path))
156
207
  if file.is_file():
157
208
  self.func(file)
158
209
 
@@ -171,8 +222,8 @@ def chdir_artifact(
171
222
  to the original directory after the context is exited.
172
223
 
173
224
  Args:
174
- run: The run to get the artifact directory from.
175
- artifact_path: The artifact path.
225
+ run (Run): The run to get the artifact directory from.
226
+ artifact_path (str | None): The artifact path.
176
227
  """
177
228
  curdir = Path.cwd()
178
229
  path = mlflow.artifacts.download_artifacts(
@@ -0,0 +1,124 @@
1
+ """
2
+ This module provides functionality to log parameters from Hydra
3
+ configuration objects and set up experiments using MLflow.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from pathlib import Path
9
+ from typing import TYPE_CHECKING
10
+
11
+ import mlflow
12
+ from hydra.core.hydra_config import HydraConfig
13
+ from mlflow.tracking import artifact_utils
14
+ from omegaconf import OmegaConf
15
+
16
+ from hydraflow.config import iter_params
17
+
18
+ if TYPE_CHECKING:
19
+ from mlflow.entities.experiment import Experiment
20
+
21
+
22
+ def set_experiment(
23
+ prefix: str = "",
24
+ suffix: str = "",
25
+ uri: str | Path | None = None,
26
+ ) -> Experiment:
27
+ """
28
+ Set the experiment name and tracking URI optionally.
29
+
30
+ This function sets the experiment name by combining the given prefix,
31
+ the job name from HydraConfig, and the given suffix. Optionally, it can
32
+ also set the tracking URI.
33
+
34
+ Args:
35
+ prefix (str): The prefix to prepend to the experiment name.
36
+ suffix (str): The suffix to append to the experiment name.
37
+ uri (str | Path | None): The tracking URI to use. Defaults to None.
38
+
39
+ Returns:
40
+ Experiment: An instance of `mlflow.entities.Experiment` representing
41
+ the new active experiment.
42
+ """
43
+ if uri is not None:
44
+ mlflow.set_tracking_uri(uri)
45
+
46
+ hc = HydraConfig.get()
47
+ name = f"{prefix}{hc.job.name}{suffix}"
48
+ return mlflow.set_experiment(name)
49
+
50
+
51
+ def log_params(config: object, *, synchronous: bool | None = None) -> None:
52
+ """
53
+ Log the parameters from the given configuration object.
54
+
55
+ This method logs the parameters from the provided configuration object
56
+ using MLflow. It iterates over the parameters and logs them using the
57
+ `mlflow.log_param` method.
58
+
59
+ Args:
60
+ config (object): The configuration object to log the parameters from.
61
+ synchronous (bool | None): Whether to log the parameters synchronously.
62
+ Defaults to None.
63
+ """
64
+ for key, value in iter_params(config):
65
+ mlflow.log_param(key, value, synchronous=synchronous)
66
+
67
+
68
+ def get_artifact_dir(
69
+ artifact_path: str | None = None,
70
+ *,
71
+ run_id: str | None = None,
72
+ ) -> Path:
73
+ """
74
+ Get the artifact directory for the given artifact path.
75
+
76
+ This function retrieves the artifact URI for the specified artifact path
77
+ using MLflow, downloads the artifacts to a local directory, and returns
78
+ the path to that directory.
79
+
80
+ Args:
81
+ artifact_path (str | None): The artifact path for which to get the
82
+ directory. Defaults to None.
83
+ run_id (str | None): The run ID for which to get the artifact directory.
84
+
85
+ Returns:
86
+ The local path to the directory where the artifacts are downloaded.
87
+ """
88
+ if run_id is None:
89
+ uri = mlflow.get_artifact_uri(artifact_path)
90
+ else:
91
+ uri = artifact_utils.get_artifact_uri(run_id, artifact_path)
92
+
93
+ dir = mlflow.artifacts.download_artifacts(artifact_uri=uri)
94
+
95
+ return Path(dir)
96
+
97
+
98
+ def get_hydra_output_dir(*, run_id: str | None = None) -> Path:
99
+ if run_id is None:
100
+ hc = HydraConfig.get()
101
+ return Path(hc.runtime.output_dir)
102
+
103
+ path = get_artifact_dir(run_id=run_id) / ".hydra/hydra.yaml"
104
+
105
+ if path.exists():
106
+ hc = OmegaConf.load(path)
107
+ return Path(hc.hydra.runtime.output_dir)
108
+
109
+ raise FileNotFoundError
110
+
111
+
112
+ # def log_hydra_output_dir(run: Run_ | Series | str) -> None:
113
+ # """
114
+ # Log the Hydra output directory.
115
+
116
+ # Args:
117
+ # run: The run object.
118
+
119
+ # Returns:
120
+ # None
121
+ # """
122
+ # output_dir = get_hydra_output_dir(run)
123
+ # run_id = run if isinstance(run, str) else run.info.run_id
124
+ # mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)