hydraflow 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hydraflow/__init__.py CHANGED
@@ -1,28 +1,22 @@
1
- from .context import Info, chdir_artifact, log_run, watch
2
- from .mlflow import set_experiment
1
+ from .context import chdir_artifact, log_run, start_run, watch
2
+ from .mlflow import get_artifact_dir, get_hydra_output_dir, set_experiment
3
3
  from .runs import (
4
- Run,
5
4
  RunCollection,
6
- filter_runs,
7
- get_param_dict,
8
- get_param_names,
9
- get_run,
5
+ list_runs,
10
6
  load_config,
11
7
  search_runs,
12
8
  )
13
9
 
14
10
  __all__ = [
15
- "Info",
16
- "Run",
17
11
  "RunCollection",
18
12
  "chdir_artifact",
19
- "filter_runs",
20
- "get_param_dict",
21
- "get_param_names",
22
- "get_run",
13
+ "get_artifact_dir",
14
+ "get_hydra_output_dir",
15
+ "list_runs",
23
16
  "load_config",
24
17
  "log_run",
25
18
  "search_runs",
26
19
  "set_experiment",
20
+ "start_run",
27
21
  "watch",
28
22
  ]
hydraflow/asyncio.py ADDED
@@ -0,0 +1,199 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import logging
5
+ from asyncio.subprocess import PIPE
6
+ from pathlib import Path
7
+ from typing import TYPE_CHECKING
8
+
9
+ import watchfiles
10
+
11
+ if TYPE_CHECKING:
12
+ from asyncio.streams import StreamReader
13
+ from collections.abc import Callable
14
+
15
+ from watchfiles import Change
16
+
17
+
18
+ # Set up logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
async def execute_command(
    program: str,
    *args: str,
    stdout: Callable[[str], None] | None = None,
    stderr: Callable[[str], None] | None = None,
    stop_event: asyncio.Event,
) -> int:
    """
    Run a command asynchronously and pass its output to callback functions.

    A stream whose callback is None is redirected to the null device rather
    than to a pipe: a pipe that nobody reads can fill up and block the child
    forever, which would make this coroutine hang on `process.wait()`.

    If starting the process or handling its output raises, the error is
    logged, the child (if it is still running) is killed, and 1 is returned.

    Args:
        program (str): The program to run.
        *args (str): Arguments for the program.
        stdout (Callable[[str], None] | None): Callback for standard output.
        stderr (Callable[[str], None] | None): Callback for standard error.
        stop_event (asyncio.Event): Event that is set when this coroutine
            finishes, whether or not the command succeeded.

    Returns:
        int: The return code of the process, or 1 if an error occurred.
    """
    from asyncio.subprocess import DEVNULL

    process = None
    try:
        process = await asyncio.create_subprocess_exec(
            program,
            *args,
            # Only pipe a stream when someone is going to consume it.
            stdout=PIPE if stdout is not None else DEVNULL,
            stderr=PIPE if stderr is not None else DEVNULL,
        )
        await asyncio.gather(
            process_stream(process.stdout, stdout),
            process_stream(process.stderr, stderr),
        )
        returncode = await process.wait()

    except Exception as e:
        logger.error(f"Error running command: {e}")
        returncode = 1

    finally:
        # If output handling failed or this task was cancelled, the child may
        # still be running; terminate it instead of leaving it orphaned.
        if process is not None and process.returncode is None:
            try:
                process.kill()
            except ProcessLookupError:
                pass
        stop_event.set()

    return returncode
59
+
60
+
61
async def process_stream(
    stream: StreamReader | None,
    callback: Callable[[str], None] | None,
) -> None:
    """
    Read a stream asynchronously and pass each line to a callback function.

    The stream is always read until EOF, even when no callback is given, so
    that a child process writing into a pipe never blocks on a full buffer.

    Args:
        stream (StreamReader | None): The stream to read from. If None,
            nothing is read.
        callback (Callable[[str], None] | None): The callback function to
            handle each line. Each line is decoded as UTF-8 (undecodable bytes
            are replaced with U+FFFD) and stripped of surrounding whitespace.
            If None, lines are read and discarded.
    """
    if stream is None:
        return

    while True:
        line = await stream.readline()
        if not line:
            # An empty bytes object means EOF.
            break
        if callback is not None:
            # errors="replace" keeps a single bad byte from aborting the read.
            callback(line.decode(errors="replace").strip())
82
+
83
+
84
async def monitor_file_changes(
    paths: list[str | Path],
    callback: Callable[[set[tuple[Change, str]]], None],
    stop_event: asyncio.Event,
    **awatch_kwargs,
) -> None:
    """
    Watch the given paths and hand every batch of changes to a callback.

    Any exception raised while watching, including one raised by `callback`,
    is logged and not re-raised.

    Args:
        paths (list[str | Path]): Paths to monitor for changes.
        callback (Callable[[set[tuple[Change, str]]], None]): Function that
            receives each set of detected changes.
        stop_event (asyncio.Event): Event passed to `watchfiles.awatch` as its
            `stop_event`, signalling when to stop watching.
        **awatch_kwargs: Extra keyword arguments forwarded to
            `watchfiles.awatch`.
    """
    targets = list(map(str, paths))

    try:
        watcher = watchfiles.awatch(*targets, stop_event=stop_event, **awatch_kwargs)
        async for batch in watcher:
            callback(batch)
    except Exception as e:
        logger.error(f"Error watching files: {e}")
107
+
108
+
109
async def run_and_monitor(
    program: str,
    *args: str,
    stdout: Callable[[str], None] | None = None,
    stderr: Callable[[str], None] | None = None,
    watch: Callable[[set[tuple[Change, str]]], None] | None = None,
    paths: list[str | Path] | None = None,
    **awatch_kwargs,
) -> int:
    """
    Run a command and, optionally, watch for file changes at the same time.

    The file watcher is started only when both `watch` and `paths` are truthy.
    A single event is shared between the command and the watcher, so the
    watcher is told to stop once the command has finished.

    Args:
        program (str): The program to run.
        *args (str): Arguments for the program.
        stdout (Callable[[str], None] | None): Callback for standard output.
        stderr (Callable[[str], None] | None): Callback for standard error.
        watch (Callable[[set[tuple[Change, str]]], None] | None): Callback for
            file changes.
        paths (list[str | Path] | None): Paths to monitor for changes.
        **awatch_kwargs: Extra keyword arguments forwarded to
            `watchfiles.awatch`.

    Returns:
        int: The return code reported by `execute_command`.
    """
    finished = asyncio.Event()

    command = execute_command(
        program,
        *args,
        stop_event=finished,
        stdout=stdout,
        stderr=stderr,
    )
    run_task = asyncio.create_task(command)

    monitor_task = (
        asyncio.create_task(
            monitor_file_changes(paths, watch, finished, **awatch_kwargs)
        )
        if watch and paths
        else None
    )

    try:
        if monitor_task is None:
            await run_task
        else:
            await asyncio.gather(run_task, monitor_task)

    except Exception as e:
        logger.error(f"Error in run_and_monitor: {e}")

    finally:
        # Always release the watcher, then wait for both tasks to settle.
        finished.set()
        await run_task
        if monitor_task is not None:
            await monitor_task

    return run_task.result()
156
+
157
+
158
def run(
    program: str,
    *args: str,
    stdout: Callable[[str], None] | None = None,
    stderr: Callable[[str], None] | None = None,
    watch: Callable[[set[tuple[Change, str]]], None] | None = None,
    paths: list[str | Path] | None = None,
    **awatch_kwargs,
) -> int:
    """
    Run a command synchronously, optionally watching for file changes.

    This is a blocking wrapper that drives `run_and_monitor` with
    `asyncio.run`. When a `watch` callback is given but `paths` is empty or
    None, the current working directory is watched.

    Args:
        program (str): The program to run.
        *args (str): Arguments for the program.
        stdout (Callable[[str], None] | None): Callback for standard output
            lines.
        stderr (Callable[[str], None] | None): Callback for standard error
            lines.
        watch (Callable[[set[tuple[Change, str]]], None] | None): Callback for
            file changes.
        paths (list[str | Path] | None): Paths to monitor for file changes.
        **awatch_kwargs: Extra keyword arguments forwarded to
            `watchfiles.awatch`.

    Returns:
        int: The return code of the process.
    """
    watch_paths = [Path.cwd()] if watch and not paths else paths

    coroutine = run_and_monitor(
        program,
        *args,
        stdout=stdout,
        stderr=stderr,
        watch=watch,
        paths=watch_paths,
        **awatch_kwargs,
    )
    return asyncio.run(coroutine)
hydraflow/config.py CHANGED
@@ -22,9 +22,9 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
22
22
  representing the parameters. The keys are prefixed with the provided prefix.
23
23
 
24
24
  Args:
25
- config: The configuration object to iterate over. This can be a dictionary,
26
- list, DictConfig, or ListConfig.
27
- prefix: The prefix to prepend to the parameter keys.
25
+ config (object): The configuration object to iterate over. This can be a
26
+ dictionary, list, DictConfig, or ListConfig.
27
+ prefix (str): The prefix to prepend to the parameter keys.
28
28
  Defaults to an empty string.
29
29
 
30
30
  Yields:
hydraflow/context.py CHANGED
@@ -9,7 +9,6 @@ import logging
9
9
  import os
10
10
  import time
11
11
  from contextlib import contextmanager
12
- from dataclasses import dataclass
13
12
  from pathlib import Path
14
13
  from typing import TYPE_CHECKING
15
14
 
@@ -28,18 +27,12 @@ if TYPE_CHECKING:
28
27
  log = logging.getLogger(__name__)
29
28
 
30
29
 
31
- @dataclass
32
- class Info:
33
- output_dir: Path
34
- artifact_dir: Path
35
-
36
-
37
30
  @contextmanager
38
31
  def log_run(
39
32
  config: object,
40
33
  *,
41
34
  synchronous: bool | None = None,
42
- ) -> Iterator[Info]:
35
+ ) -> Iterator[None]:
43
36
  """
44
37
  Log the parameters from the given configuration object and manage the MLflow
45
38
  run context.
@@ -49,16 +42,15 @@ def log_run(
49
42
  are logged and the run is properly closed.
50
43
 
51
44
  Args:
52
- config: The configuration object to log the parameters from.
53
- synchronous: Whether to log the parameters synchronously.
45
+ config (object): The configuration object to log the parameters from.
46
+ synchronous (bool | None): Whether to log the parameters synchronously.
54
47
  Defaults to None.
55
48
 
56
49
  Yields:
57
- Info: An `Info` object containing the output directory and artifact directory
58
- paths.
50
+ None
59
51
 
60
52
  Example:
61
- with log_run(config) as info:
53
+ with log_run(config):
62
54
  # Perform operations within the MLflow run context
63
55
  pass
64
56
  """
@@ -66,7 +58,6 @@ def log_run(
66
58
 
67
59
  hc = HydraConfig.get()
68
60
  output_dir = Path(hc.runtime.output_dir)
69
- info = Info(output_dir, get_artifact_dir())
70
61
 
71
62
  # Save '.hydra' config directory first.
72
63
  output_subdir = output_dir / (hc.output_subdir or "")
@@ -78,7 +69,7 @@ def log_run(
78
69
 
79
70
  try:
80
71
  with watch(log_artifact, output_dir):
81
- yield info
72
+ yield
82
73
 
83
74
  except Exception as e:
84
75
  log.error(f"Error during log_run: {e}")
@@ -89,6 +80,64 @@ def log_run(
89
80
  mlflow.log_artifacts(output_dir.as_posix())
90
81
 
91
82
 
83
@contextmanager
def start_run(
    config: object,
    *,
    run_id: str | None = None,
    experiment_id: str | None = None,
    run_name: str | None = None,
    nested: bool = False,
    parent_run_id: str | None = None,
    tags: dict[str, str] | None = None,
    description: str | None = None,
    log_system_metrics: bool | None = None,
    synchronous: bool | None = None,
) -> Iterator[Run]:
    """
    Start an MLflow run and log the parameters of a configuration object.

    The run is opened with `mlflow.start_run`, using every argument except
    `config` and `synchronous`. Inside it, `log_run(config,
    synchronous=synchronous)` is entered. Both contexts are exited in
    reverse order when the `with` block ends.

    Args:
        config (object): The configuration object to log parameters from.
        run_id (str | None): The existing run ID. Defaults to None.
        experiment_id (str | None): The experiment ID. Defaults to None.
        run_name (str | None): The name of the run. Defaults to None.
        nested (bool): Whether to allow nested runs. Defaults to False.
        parent_run_id (str | None): The parent run ID. Defaults to None.
        tags (dict[str, str] | None): Tags to associate with the run.
            Defaults to None.
        description (str | None): A description of the run. Defaults to None.
        log_system_metrics (bool | None): Whether to log system metrics.
            Defaults to None.
        synchronous (bool | None): Whether to log parameters synchronously.
            Defaults to None.

    Yields:
        Run: The object returned by entering `mlflow.start_run`.

    Example:
        with start_run(config) as run:
            # Perform operations within the MLflow run context
            pass

    See Also:
        `mlflow.start_run`: The MLflow function to start a run directly.
        `log_run`: A context manager to log parameters and manage the MLflow
        run context.
    """
    with mlflow.start_run(
        run_id=run_id,
        experiment_id=experiment_id,
        run_name=run_name,
        nested=nested,
        parent_run_id=parent_run_id,
        tags=tags,
        description=description,
        log_system_metrics=log_system_metrics,
    ) as run, log_run(config, synchronous=synchronous):
        yield run
139
+
140
+
92
141
  @contextmanager
93
142
  def watch(
94
143
  func: Callable[[Path], None],
@@ -105,12 +154,12 @@ def watch(
105
154
  period or until the context is exited.
106
155
 
107
156
  Args:
108
- func: The function to call when a change is
157
+ func (Callable[[Path], None]): The function to call when a change is
109
158
  detected. It should accept a single argument of type `Path`,
110
159
  which is the path of the modified file.
111
- dir: The directory to watch. If not specified,
160
+ dir (Path | str): The directory to watch. If not specified,
112
161
  the current MLflow artifact URI is used. Defaults to "".
113
- timeout: The timeout period in seconds for the watcher
162
+ timeout (int): The timeout period in seconds for the watcher
114
163
  to run after the context is exited. Defaults to 60.
115
164
 
116
165
  Yields:
@@ -122,6 +171,8 @@ def watch(
122
171
  pass
123
172
  """
124
173
  dir = dir or get_artifact_dir()
174
+ if isinstance(dir, Path):
175
+ dir = dir.as_posix()
125
176
 
126
177
  handler = Handler(func)
127
178
  observer = Observer()
@@ -152,7 +203,7 @@ class Handler(FileSystemEventHandler):
152
203
  self.func = func
153
204
 
154
205
  def on_modified(self, event: FileModifiedEvent) -> None:
155
- file = Path(event.src_path)
206
+ file = Path(str(event.src_path))
156
207
  if file.is_file():
157
208
  self.func(file)
158
209
 
@@ -171,8 +222,8 @@ def chdir_artifact(
171
222
  to the original directory after the context is exited.
172
223
 
173
224
  Args:
174
- run: The run to get the artifact directory from.
175
- artifact_path: The artifact path.
225
+ run (Run): The run to get the artifact directory from.
226
+ artifact_path (str | None): The artifact path.
176
227
  """
177
228
  curdir = Path.cwd()
178
229
  path = mlflow.artifacts.download_artifacts(
hydraflow/mlflow.py CHANGED
@@ -6,14 +6,24 @@ configuration objects and set up experiments using MLflow.
6
6
  from __future__ import annotations
7
7
 
8
8
  from pathlib import Path
9
+ from typing import TYPE_CHECKING
9
10
 
10
11
  import mlflow
11
12
  from hydra.core.hydra_config import HydraConfig
13
+ from mlflow.tracking import artifact_utils
14
+ from omegaconf import OmegaConf
12
15
 
13
16
  from hydraflow.config import iter_params
14
17
 
18
+ if TYPE_CHECKING:
19
+ from mlflow.entities.experiment import Experiment
15
20
 
16
- def set_experiment(prefix: str = "", suffix: str = "", uri: str | None = None) -> None:
21
+
22
+ def set_experiment(
23
+ prefix: str = "",
24
+ suffix: str = "",
25
+ uri: str | Path | None = None,
26
+ ) -> Experiment:
17
27
  """
18
28
  Set the experiment name and tracking URI optionally.
19
29
 
@@ -22,16 +32,20 @@ def set_experiment(prefix: str = "", suffix: str = "", uri: str | None = None) -
22
32
  also set the tracking URI.
23
33
 
24
34
  Args:
25
- prefix: The prefix to prepend to the experiment name.
26
- suffix: The suffix to append to the experiment name.
27
- uri: The tracking URI to use.
35
+ prefix (str): The prefix to prepend to the experiment name.
36
+ suffix (str): The suffix to append to the experiment name.
37
+ uri (str | Path | None): The tracking URI to use. Defaults to None.
38
+
39
+ Returns:
40
+ Experiment: An instance of `mlflow.entities.Experiment` representing
41
+ the new active experiment.
28
42
  """
29
- if uri:
43
+ if uri is not None:
30
44
  mlflow.set_tracking_uri(uri)
31
45
 
32
46
  hc = HydraConfig.get()
33
47
  name = f"{prefix}{hc.job.name}{suffix}"
34
- mlflow.set_experiment(name)
48
+ return mlflow.set_experiment(name)
35
49
 
36
50
 
37
51
  def log_params(config: object, *, synchronous: bool | None = None) -> None:
@@ -43,15 +57,19 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
43
57
  `mlflow.log_param` method.
44
58
 
45
59
  Args:
46
- config: The configuration object to log the parameters from.
47
- synchronous: Whether to log the parameters synchronously.
60
+ config (object): The configuration object to log the parameters from.
61
+ synchronous (bool | None): Whether to log the parameters synchronously.
48
62
  Defaults to None.
49
63
  """
50
64
  for key, value in iter_params(config):
51
65
  mlflow.log_param(key, value, synchronous=synchronous)
52
66
 
53
67
 
54
- def get_artifact_dir(artifact_path: str | None = None) -> Path:
68
+ def get_artifact_dir(
69
+ artifact_path: str | None = None,
70
+ *,
71
+ run_id: str | None = None,
72
+ ) -> Path:
55
73
  """
56
74
  Get the artifact directory for the given artifact path.
57
75
 
@@ -60,13 +78,47 @@ def get_artifact_dir(artifact_path: str | None = None) -> Path:
60
78
  the path to that directory.
61
79
 
62
80
  Args:
63
- artifact_path: The artifact path for which to get the directory.
64
- Defaults to None.
81
+ artifact_path (str | None): The artifact path for which to get the
82
+ directory. Defaults to None.
83
+ run_id (str | None): The run ID for which to get the artifact directory.
65
84
 
66
85
  Returns:
67
86
  The local path to the directory where the artifacts are downloaded.
68
87
  """
69
- uri = mlflow.get_artifact_uri(artifact_path)
88
+ if run_id is None:
89
+ uri = mlflow.get_artifact_uri(artifact_path)
90
+ else:
91
+ uri = artifact_utils.get_artifact_uri(run_id, artifact_path)
92
+
70
93
  dir = mlflow.artifacts.download_artifacts(artifact_uri=uri)
71
94
 
72
95
  return Path(dir)
96
+
97
+
98
def get_hydra_output_dir(*, run_id: str | None = None) -> Path:
    """
    Get the Hydra output directory.

    Args:
        run_id (str | None): The run ID. If None, the output directory is
            taken from `HydraConfig.get()` (the current Hydra job). Otherwise
            it is read from the `.hydra/hydra.yaml` artifact of the given run.
            Defaults to None.

    Returns:
        Path: The path to the Hydra output directory.

    Raises:
        FileNotFoundError: If the run has no `.hydra/hydra.yaml` artifact.
    """
    if run_id is None:
        hc = HydraConfig.get()
        return Path(hc.runtime.output_dir)

    path = get_artifact_dir(run_id=run_id) / ".hydra/hydra.yaml"

    if not path.exists():
        # Include the path so callers can tell which run/artifact is missing.
        msg = f"Hydra config file not found for run {run_id!r}: {path}"
        raise FileNotFoundError(msg)

    hc = OmegaConf.load(path)
    return Path(hc.hydra.runtime.output_dir)
110
+
111
+
112
+ # def log_hydra_output_dir(run: Run_ | Series | str) -> None:
113
+ # """
114
+ # Log the Hydra output directory.
115
+
116
+ # Args:
117
+ # run: The run object.
118
+
119
+ # Returns:
120
+ # None
121
+ # """
122
+ # output_dir = get_hydra_output_dir(run)
123
+ # run_id = run if isinstance(run, str) else run.info.run_id
124
+ # mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)