hydraflow 0.2.3__tar.gz → 0.2.5__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (29) hide show
  1. {hydraflow-0.2.3 → hydraflow-0.2.5}/.devcontainer/devcontainer.json +1 -2
  2. {hydraflow-0.2.3 → hydraflow-0.2.5}/PKG-INFO +4 -7
  3. {hydraflow-0.2.3 → hydraflow-0.2.5}/README.md +3 -6
  4. {hydraflow-0.2.3 → hydraflow-0.2.5}/pyproject.toml +1 -1
  5. hydraflow-0.2.5/src/hydraflow/__init__.py +22 -0
  6. {hydraflow-0.2.3 → hydraflow-0.2.5}/src/hydraflow/config.py +3 -3
  7. {hydraflow-0.2.3 → hydraflow-0.2.5}/src/hydraflow/context.py +72 -21
  8. hydraflow-0.2.5/src/hydraflow/mlflow.py +124 -0
  9. {hydraflow-0.2.3 → hydraflow-0.2.5}/src/hydraflow/runs.py +123 -94
  10. {hydraflow-0.2.3 → hydraflow-0.2.5}/tests/scripts/log_run.py +7 -5
  11. hydraflow-0.2.5/tests/test_context.py +80 -0
  12. {hydraflow-0.2.3 → hydraflow-0.2.5}/tests/test_log_run.py +1 -1
  13. {hydraflow-0.2.3 → hydraflow-0.2.5}/tests/test_runs.py +38 -0
  14. hydraflow-0.2.3/src/hydraflow/__init__.py +0 -30
  15. hydraflow-0.2.3/src/hydraflow/mlflow.py +0 -72
  16. hydraflow-0.2.3/tests/test_context.py +0 -36
  17. {hydraflow-0.2.3 → hydraflow-0.2.5}/.devcontainer/postCreate.sh +0 -0
  18. {hydraflow-0.2.3 → hydraflow-0.2.5}/.devcontainer/starship.toml +0 -0
  19. {hydraflow-0.2.3 → hydraflow-0.2.5}/.gitattributes +0 -0
  20. {hydraflow-0.2.3 → hydraflow-0.2.5}/.gitignore +0 -0
  21. {hydraflow-0.2.3 → hydraflow-0.2.5}/LICENSE +0 -0
  22. {hydraflow-0.2.3 → hydraflow-0.2.5}/src/hydraflow/asyncio.py +0 -0
  23. {hydraflow-0.2.3 → hydraflow-0.2.5}/tests/scripts/__init__.py +0 -0
  24. {hydraflow-0.2.3 → hydraflow-0.2.5}/tests/scripts/watch.py +0 -0
  25. {hydraflow-0.2.3 → hydraflow-0.2.5}/tests/test_asyncio.py +0 -0
  26. {hydraflow-0.2.3 → hydraflow-0.2.5}/tests/test_config.py +0 -0
  27. {hydraflow-0.2.3 → hydraflow-0.2.5}/tests/test_mlflow.py +0 -0
  28. {hydraflow-0.2.3 → hydraflow-0.2.5}/tests/test_version.py +0 -0
  29. {hydraflow-0.2.3 → hydraflow-0.2.5}/tests/test_watch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "hydraflow",
3
- "image": "mcr.microsoft.com/vscode/devcontainers/base:ubuntu-22.04",
3
+ "image": "mcr.microsoft.com/vscode/devcontainers/python:3.12",
4
4
  "features": {
5
5
  "ghcr.io/devcontainers-contrib/features/starship:1": {},
6
6
  "ghcr.io/va-h/devcontainers-features/uv:1": {}
@@ -9,7 +9,6 @@
9
9
  "vscode": {
10
10
  "extensions": [
11
11
  "charliermarsh.ruff",
12
- "henriiik.vscode-sort",
13
12
  "ms-python.python",
14
13
  "ms-python.vscode-pylance"
15
14
  ]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -48,7 +48,7 @@ Description-Content-Type: text/markdown
48
48
 
49
49
  ## Overview
50
50
 
51
- Hydraflow is a powerful library designed to seamlessly integrate
51
+ Hydraflow is a library designed to seamlessly integrate
52
52
  [Hydra](https://hydra.cc/) and [MLflow](https://mlflow.org/), making it easier to
53
53
  manage and track machine learning experiments. By combining the flexibility of
54
54
  Hydra's configuration management with the robust experiment tracking capabilities
@@ -99,13 +99,10 @@ def my_app(cfg: MySQLConfig) -> None:
99
99
  # Set experiment by Hydra job name.
100
100
  hydraflow.set_experiment()
101
101
 
102
- # Automatically log params using Hydra config.
103
- with mlflow.start_run(), hydraflow.log_run(cfg) as info:
102
+ # Automatically log Hydra config as params.
103
+ with hydraflow.start_run():
104
104
  # Your app code below.
105
105
 
106
- # `info.output_dir` is the Hydra output directory.
107
- # `info.artifact_dir` is the MLflow artifact directory.
108
-
109
106
  with hydraflow.watch(callback):
110
107
  # Watch files in the MLflow artifact directory.
111
108
  # You can update metrics or log other artifacts
@@ -17,7 +17,7 @@
17
17
 
18
18
  ## Overview
19
19
 
20
- Hydraflow is a powerful library designed to seamlessly integrate
20
+ Hydraflow is a library designed to seamlessly integrate
21
21
  [Hydra](https://hydra.cc/) and [MLflow](https://mlflow.org/), making it easier to
22
22
  manage and track machine learning experiments. By combining the flexibility of
23
23
  Hydra's configuration management with the robust experiment tracking capabilities
@@ -68,13 +68,10 @@ def my_app(cfg: MySQLConfig) -> None:
68
68
  # Set experiment by Hydra job name.
69
69
  hydraflow.set_experiment()
70
70
 
71
- # Automatically log params using Hydra config.
72
- with mlflow.start_run(), hydraflow.log_run(cfg) as info:
71
+ # Automatically log Hydra config as params.
72
+ with hydraflow.start_run():
73
73
  # Your app code below.
74
74
 
75
- # `info.output_dir` is the Hydra output directory.
76
- # `info.artifact_dir` is the MLflow artifact directory.
77
-
78
75
  with hydraflow.watch(callback):
79
76
  # Watch files in the MLflow artifact directory.
80
77
  # You can update metrics or log other artifacts
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.2.3"
7
+ version = "0.2.5"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -0,0 +1,22 @@
1
+ from .context import chdir_artifact, log_run, start_run, watch
2
+ from .mlflow import get_artifact_dir, get_hydra_output_dir, set_experiment
3
+ from .runs import (
4
+ RunCollection,
5
+ list_runs,
6
+ load_config,
7
+ search_runs,
8
+ )
9
+
10
+ __all__ = [
11
+ "RunCollection",
12
+ "chdir_artifact",
13
+ "get_artifact_dir",
14
+ "get_hydra_output_dir",
15
+ "list_runs",
16
+ "load_config",
17
+ "log_run",
18
+ "search_runs",
19
+ "set_experiment",
20
+ "start_run",
21
+ "watch",
22
+ ]
@@ -22,9 +22,9 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
22
22
  representing the parameters. The keys are prefixed with the provided prefix.
23
23
 
24
24
  Args:
25
- config: The configuration object to iterate over. This can be a dictionary,
26
- list, DictConfig, or ListConfig.
27
- prefix: The prefix to prepend to the parameter keys.
25
+ config (object): The configuration object to iterate over. This can be a
26
+ dictionary, list, DictConfig, or ListConfig.
27
+ prefix (str): The prefix to prepend to the parameter keys.
28
28
  Defaults to an empty string.
29
29
 
30
30
  Yields:
@@ -9,7 +9,6 @@ import logging
9
9
  import os
10
10
  import time
11
11
  from contextlib import contextmanager
12
- from dataclasses import dataclass
13
12
  from pathlib import Path
14
13
  from typing import TYPE_CHECKING
15
14
 
@@ -28,18 +27,12 @@ if TYPE_CHECKING:
28
27
  log = logging.getLogger(__name__)
29
28
 
30
29
 
31
- @dataclass
32
- class Info:
33
- output_dir: Path
34
- artifact_dir: Path
35
-
36
-
37
30
  @contextmanager
38
31
  def log_run(
39
32
  config: object,
40
33
  *,
41
34
  synchronous: bool | None = None,
42
- ) -> Iterator[Info]:
35
+ ) -> Iterator[None]:
43
36
  """
44
37
  Log the parameters from the given configuration object and manage the MLflow
45
38
  run context.
@@ -49,16 +42,15 @@ def log_run(
49
42
  are logged and the run is properly closed.
50
43
 
51
44
  Args:
52
- config: The configuration object to log the parameters from.
53
- synchronous: Whether to log the parameters synchronously.
45
+ config (object): The configuration object to log the parameters from.
46
+ synchronous (bool | None): Whether to log the parameters synchronously.
54
47
  Defaults to None.
55
48
 
56
49
  Yields:
57
- Info: An `Info` object containing the output directory and artifact directory
58
- paths.
50
+ None
59
51
 
60
52
  Example:
61
- with log_run(config) as info:
53
+ with log_run(config):
62
54
  # Perform operations within the MLflow run context
63
55
  pass
64
56
  """
@@ -66,7 +58,6 @@ def log_run(
66
58
 
67
59
  hc = HydraConfig.get()
68
60
  output_dir = Path(hc.runtime.output_dir)
69
- info = Info(output_dir, get_artifact_dir())
70
61
 
71
62
  # Save '.hydra' config directory first.
72
63
  output_subdir = output_dir / (hc.output_subdir or "")
@@ -78,7 +69,7 @@ def log_run(
78
69
 
79
70
  try:
80
71
  with watch(log_artifact, output_dir):
81
- yield info
72
+ yield
82
73
 
83
74
  except Exception as e:
84
75
  log.error(f"Error during log_run: {e}")
@@ -89,6 +80,64 @@ def log_run(
89
80
  mlflow.log_artifacts(output_dir.as_posix())
90
81
 
91
82
 
83
+ @contextmanager
84
+ def start_run(
85
+ config: object,
86
+ *,
87
+ run_id: str | None = None,
88
+ experiment_id: str | None = None,
89
+ run_name: str | None = None,
90
+ nested: bool = False,
91
+ parent_run_id: str | None = None,
92
+ tags: dict[str, str] | None = None,
93
+ description: str | None = None,
94
+ log_system_metrics: bool | None = None,
95
+ synchronous: bool | None = None,
96
+ ) -> Iterator[Run]:
97
+ """
98
+ Start an MLflow run and log parameters using the provided configuration object.
99
+
100
+ This context manager starts an MLflow run and logs parameters using the specified
101
+ configuration object. It ensures that the run is properly closed after completion.
102
+
103
+ Args:
104
+ config (object): The configuration object to log parameters from.
105
+ run_id (str | None): The existing run ID. Defaults to None.
106
+ experiment_id (str | None): The experiment ID. Defaults to None.
107
+ run_name (str | None): The name of the run. Defaults to None.
108
+ nested (bool): Whether to allow nested runs. Defaults to False.
109
+ parent_run_id (str | None): The parent run ID. Defaults to None.
110
+ tags (dict[str, str] | None): Tags to associate with the run. Defaults to None.
111
+ description (str | None): A description of the run. Defaults to None.
112
+ log_system_metrics (bool | None): Whether to log system metrics. Defaults to None.
113
+ synchronous (bool | None): Whether to log parameters synchronously. Defaults to None.
114
+
115
+ Yields:
116
+ Run: An MLflow Run object representing the started run.
117
+
118
+ Example:
119
+ with start_run(config) as run:
120
+ # Perform operations within the MLflow run context
121
+ pass
122
+
123
+ See Also:
124
+ `mlflow.start_run`: The MLflow function to start a run directly.
125
+ `log_run`: A context manager to log parameters and manage the MLflow run context.
126
+ """
127
+ with mlflow.start_run(
128
+ run_id=run_id,
129
+ experiment_id=experiment_id,
130
+ run_name=run_name,
131
+ nested=nested,
132
+ parent_run_id=parent_run_id,
133
+ tags=tags,
134
+ description=description,
135
+ log_system_metrics=log_system_metrics,
136
+ ) as run:
137
+ with log_run(config, synchronous=synchronous):
138
+ yield run
139
+
140
+
92
141
  @contextmanager
93
142
  def watch(
94
143
  func: Callable[[Path], None],
@@ -105,12 +154,12 @@ def watch(
105
154
  period or until the context is exited.
106
155
 
107
156
  Args:
108
- func: The function to call when a change is
157
+ func (Callable[[Path], None]): The function to call when a change is
109
158
  detected. It should accept a single argument of type `Path`,
110
159
  which is the path of the modified file.
111
- dir: The directory to watch. If not specified,
160
+ dir (Path | str): The directory to watch. If not specified,
112
161
  the current MLflow artifact URI is used. Defaults to "".
113
- timeout: The timeout period in seconds for the watcher
162
+ timeout (int): The timeout period in seconds for the watcher
114
163
  to run after the context is exited. Defaults to 60.
115
164
 
116
165
  Yields:
@@ -122,6 +171,8 @@ def watch(
122
171
  pass
123
172
  """
124
173
  dir = dir or get_artifact_dir()
174
+ if isinstance(dir, Path):
175
+ dir = dir.as_posix()
125
176
 
126
177
  handler = Handler(func)
127
178
  observer = Observer()
@@ -152,7 +203,7 @@ class Handler(FileSystemEventHandler):
152
203
  self.func = func
153
204
 
154
205
  def on_modified(self, event: FileModifiedEvent) -> None:
155
- file = Path(event.src_path)
206
+ file = Path(str(event.src_path))
156
207
  if file.is_file():
157
208
  self.func(file)
158
209
 
@@ -171,8 +222,8 @@ def chdir_artifact(
171
222
  to the original directory after the context is exited.
172
223
 
173
224
  Args:
174
- run: The run to get the artifact directory from.
175
- artifact_path: The artifact path.
225
+ run (Run): The run to get the artifact directory from.
226
+ artifact_path (str | None): The artifact path.
176
227
  """
177
228
  curdir = Path.cwd()
178
229
  path = mlflow.artifacts.download_artifacts(
@@ -0,0 +1,124 @@
1
+ """
2
+ This module provides functionality to log parameters from Hydra
3
+ configuration objects and set up experiments using MLflow.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ from pathlib import Path
9
+ from typing import TYPE_CHECKING
10
+
11
+ import mlflow
12
+ from hydra.core.hydra_config import HydraConfig
13
+ from mlflow.tracking import artifact_utils
14
+ from omegaconf import OmegaConf
15
+
16
+ from hydraflow.config import iter_params
17
+
18
+ if TYPE_CHECKING:
19
+ from mlflow.entities.experiment import Experiment
20
+
21
+
22
+ def set_experiment(
23
+ prefix: str = "",
24
+ suffix: str = "",
25
+ uri: str | Path | None = None,
26
+ ) -> Experiment:
27
+ """
28
+ Set the experiment name and tracking URI optionally.
29
+
30
+ This function sets the experiment name by combining the given prefix,
31
+ the job name from HydraConfig, and the given suffix. Optionally, it can
32
+ also set the tracking URI.
33
+
34
+ Args:
35
+ prefix (str): The prefix to prepend to the experiment name.
36
+ suffix (str): The suffix to append to the experiment name.
37
+ uri (str | Path | None): The tracking URI to use. Defaults to None.
38
+
39
+ Returns:
40
+ Experiment: An instance of `mlflow.entities.Experiment` representing
41
+ the new active experiment.
42
+ """
43
+ if uri is not None:
44
+ mlflow.set_tracking_uri(uri)
45
+
46
+ hc = HydraConfig.get()
47
+ name = f"{prefix}{hc.job.name}{suffix}"
48
+ return mlflow.set_experiment(name)
49
+
50
+
51
+ def log_params(config: object, *, synchronous: bool | None = None) -> None:
52
+ """
53
+ Log the parameters from the given configuration object.
54
+
55
+ This method logs the parameters from the provided configuration object
56
+ using MLflow. It iterates over the parameters and logs them using the
57
+ `mlflow.log_param` method.
58
+
59
+ Args:
60
+ config (object): The configuration object to log the parameters from.
61
+ synchronous (bool | None): Whether to log the parameters synchronously.
62
+ Defaults to None.
63
+ """
64
+ for key, value in iter_params(config):
65
+ mlflow.log_param(key, value, synchronous=synchronous)
66
+
67
+
68
+ def get_artifact_dir(
69
+ artifact_path: str | None = None,
70
+ *,
71
+ run_id: str | None = None,
72
+ ) -> Path:
73
+ """
74
+ Get the artifact directory for the given artifact path.
75
+
76
+ This function retrieves the artifact URI for the specified artifact path
77
+ using MLflow, downloads the artifacts to a local directory, and returns
78
+ the path to that directory.
79
+
80
+ Args:
81
+ artifact_path (str | None): The artifact path for which to get the
82
+ directory. Defaults to None.
83
+ run_id (str | None): The run ID for which to get the artifact directory.
84
+
85
+ Returns:
86
+ The local path to the directory where the artifacts are downloaded.
87
+ """
88
+ if run_id is None:
89
+ uri = mlflow.get_artifact_uri(artifact_path)
90
+ else:
91
+ uri = artifact_utils.get_artifact_uri(run_id, artifact_path)
92
+
93
+ dir = mlflow.artifacts.download_artifacts(artifact_uri=uri)
94
+
95
+ return Path(dir)
96
+
97
+
98
+ def get_hydra_output_dir(*, run_id: str | None = None) -> Path:
99
+ if run_id is None:
100
+ hc = HydraConfig.get()
101
+ return Path(hc.runtime.output_dir)
102
+
103
+ path = get_artifact_dir(run_id=run_id) / ".hydra/hydra.yaml"
104
+
105
+ if path.exists():
106
+ hc = OmegaConf.load(path)
107
+ return Path(hc.hydra.runtime.output_dir)
108
+
109
+ raise FileNotFoundError
110
+
111
+
112
+ # def log_hydra_output_dir(run: Run_ | Series | str) -> None:
113
+ # """
114
+ # Log the Hydra output directory.
115
+
116
+ # Args:
117
+ # run: The run object.
118
+
119
+ # Returns:
120
+ # None
121
+ # """
122
+ # output_dir = get_hydra_output_dir(run)
123
+ # run_id = run if isinstance(run, str) else run.info.run_id
124
+ # mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
@@ -45,28 +45,38 @@ def search_runs(
45
45
  The returned runs are sorted by their start time in ascending order.
46
46
 
47
47
  Args:
48
- experiment_ids: List of experiment IDs. Search can work with experiment
49
- IDs or experiment names, but not both in the same call. Values
50
- other than ``None`` or ``[]`` will result in error if
48
+ experiment_ids (list[str] | None): List of experiment IDs. Search can
49
+ work with experiment IDs or experiment names, but not both in the
50
+ same call. Values other than ``None`` or ``[]`` will result in
51
+ error if ``experiment_names`` is also not ``None`` or ``[]``.
52
+ ``None`` will default to the active experiment if ``experiment_names``
53
+ is ``None`` or ``[]``.
54
+ experiment_ids (list[str] | None): List of experiment IDs. Search can
55
+ work with experiment IDs or experiment names, but not both in the
56
+ same call. Values other than ``None`` or ``[]`` will result in
57
+ error if ``experiment_names`` is also not ``None`` or ``[]``.
51
58
  ``experiment_names`` is also not ``None`` or ``[]``. ``None`` will
52
59
  default to the active experiment if ``experiment_names`` is ``None``
53
60
  or ``[]``.
54
- filter_string: Filter query string, defaults to searching all runs.
55
- run_view_type: one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``, or
56
- ``ALL`` runs defined in :py:class:`mlflow.entities.ViewType`.
57
- max_results: The maximum number of runs to put in the dataframe. Default
58
- is 100,000 to avoid causing out-of-memory issues on the user's
61
+ filter_string (str): Filter query string, defaults to searching all
62
+ runs.
63
+ run_view_type (int): one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``,
64
+ or ``ALL`` runs defined in :py:class:`mlflow.entities.ViewType`.
65
+ max_results (int): The maximum number of runs to put in the dataframe.
66
+ Default is 100,000 to avoid causing out-of-memory issues on the user's
59
67
  machine.
60
- order_by: List of columns to order by (e.g., "metrics.rmse"). The
61
- ``order_by`` column can contain an optional ``DESC`` or ``ASC``
62
- value. The default is ``ASC``. The default ordering is to sort by
68
+ order_by (list[str] | None): List of columns to order by (e.g.,
69
+ "metrics.rmse"). The ``order_by`` column can contain an optional
70
+ ``DESC`` or ``ASC`` value. The default is ``ASC``. The default
71
+ ordering is to sort by ``start_time DESC``, then ``run_id``.
63
72
  ``start_time DESC``, then ``run_id``.
64
- search_all_experiments: Boolean specifying whether all experiments
65
- should be searched. Only honored if ``experiment_ids`` is ``[]`` or
66
- ``None``.
67
- experiment_names: List of experiment names. Search can work with
68
- experiment IDs or experiment names, but not both in the same call.
69
- Values other than ``None`` or ``[]`` will result in error if
73
+ search_all_experiments (bool): Boolean specifying whether all
74
+ experiments should be searched. Only honored if ``experiment_ids``
75
+ is ``[]`` or ``None``.
76
+ experiment_names (list[str] | None): List of experiment names. Search
77
+ can work with experiment IDs or experiment names, but not both in
78
+ the same call. Values other than ``None`` or ``[]`` will result in
79
+ error if ``experiment_ids`` is also not ``None`` or ``[]``.
70
80
  ``experiment_ids`` is also not ``None`` or ``[]``. ``None`` will
71
81
  default to the active experiment if ``experiment_ids`` is ``None``
72
82
  or ``[]``.
@@ -102,10 +112,10 @@ def list_runs(experiment_names: list[str] | None = None) -> RunCollection:
102
112
  The returned runs are sorted by their start time in ascending order.
103
113
 
104
114
  Args:
105
- experiment_names: List of experiment names to search for runs.
106
- If None or an empty list is provided, the function will search
107
- the currently active experiment or all experiments except the
108
- "Default" experiment.
115
+ experiment_names (list[str] | None): List of experiment names to search
116
+ for runs. If None or an empty list is provided, the function will
117
+ search the currently active experiment or all experiments except
118
+ the "Default" experiment.
109
119
 
110
120
  Returns:
111
121
  A `RunCollection` object containing the runs for the specified experiments.
@@ -138,6 +148,22 @@ class RunCollection:
138
148
  def __len__(self) -> int:
139
149
  return len(self._runs)
140
150
 
151
+ def __iter__(self) -> Iterator[Run]:
152
+ return iter(self._runs)
153
+
154
+ def __getitem__(self, index: int) -> Run:
155
+ return self._runs[index]
156
+
157
+ def __contains__(self, run: Run) -> bool:
158
+ return run in self._runs
159
+
160
+ def sort(
161
+ self,
162
+ key: Callable[[Run], Any] | None = None,
163
+ reverse: bool = False,
164
+ ) -> None:
165
+ self._runs.sort(key=key or (lambda x: x.info.start_time), reverse=reverse)
166
+
141
167
  def first(self) -> Run:
142
168
  """
143
169
  Get the first run in the collection.
@@ -206,9 +232,9 @@ class RunCollection:
206
232
  and exclusive of the upper bound).
207
233
 
208
234
  Args:
209
- config: The configuration object to filter the runs. This can be
210
- any object that provides key-value pairs through the
211
- `iter_params` function.
235
+ config (object | None): The configuration object to filter the runs.
236
+ This can be any object that provides key-value pairs through
237
+ the `iter_params` function.
212
238
  **kwargs: Additional key-value pairs to filter the runs.
213
239
 
214
240
  Returns:
@@ -226,7 +252,7 @@ class RunCollection:
226
252
  is raised.
227
253
 
228
254
  Args:
229
- config: The configuration object to identify the run.
255
+ config (object | None): The configuration object to identify the run.
230
256
  **kwargs: Additional key-value pairs to filter the runs.
231
257
 
232
258
  Returns:
@@ -251,7 +277,7 @@ class RunCollection:
251
277
  returned.
252
278
 
253
279
  Args:
254
- config: The configuration object to identify the run.
280
+ config (object | None): The configuration object to identify the run.
255
281
  **kwargs: Additional key-value pairs to filter the runs.
256
282
 
257
283
  Returns:
@@ -274,7 +300,7 @@ class RunCollection:
274
300
  is raised.
275
301
 
276
302
  Args:
277
- config: The configuration object to identify the run.
303
+ config (object | None): The configuration object to identify the run.
278
304
  **kwargs: Additional key-value pairs to filter the runs.
279
305
 
280
306
  Returns:
@@ -299,7 +325,7 @@ class RunCollection:
299
325
  returned.
300
326
 
301
327
  Args:
302
- config: The configuration object to identify the run.
328
+ config (object | None): The configuration object to identify the run.
303
329
  **kwargs: Additional key-value pairs to filter the runs.
304
330
 
305
331
  Returns:
@@ -322,7 +348,7 @@ class RunCollection:
322
348
  one run matches the criteria, a `ValueError` is raised.
323
349
 
324
350
  Args:
325
- config: The configuration object to identify the run.
351
+ config (object | None): The configuration object to identify the run.
326
352
  **kwargs: Additional key-value pairs to filter the runs.
327
353
 
328
354
  Returns:
@@ -348,7 +374,7 @@ class RunCollection:
348
374
  If more than one run matches the criteria, a `ValueError` is raised.
349
375
 
350
376
  Args:
351
- config: The configuration object to identify the run.
377
+ config (object | None): The configuration object to identify the run.
352
378
  **kwargs: Additional key-value pairs to filter the runs.
353
379
 
354
380
  Returns:
@@ -398,7 +424,8 @@ class RunCollection:
398
424
  results.
399
425
 
400
426
  Args:
401
- func: A function that takes a run and returns a result.
427
+ func (Callable[[Run], T]): A function that takes a run and returns a
428
+ result.
402
429
 
403
430
  Yields:
404
431
  Results obtained by applying the function to each run in the
@@ -412,7 +439,8 @@ class RunCollection:
412
439
  of results.
413
440
 
414
441
  Args:
415
- func: A function that takes a run id and returns a result.
442
+ func (Callable[[str], T]): A function that takes a run id and returns a
443
+ result.
416
444
 
417
445
  Yields:
418
446
  Results obtained by applying the function to each run id in the
@@ -426,8 +454,8 @@ class RunCollection:
426
454
  an iterator of results.
427
455
 
428
456
  Args:
429
- func: A function that takes a run configuration and returns a
430
- result.
457
+ func (Callable[[DictConfig], T]): A function that takes a run
458
+ configuration and returns a result.
431
459
 
432
460
  Yields:
433
461
  Results obtained by applying the function to each run configuration
@@ -445,8 +473,8 @@ class RunCollection:
445
473
  have an artifact URI, None is passed to the function.
446
474
 
447
475
  Args:
448
- func: A function that takes an artifact URI (string or None) and
449
- returns a result.
476
+ func (Callable[[str | None], T]): A function that takes an
477
+ artifact URI (string or None) and returns a result.
450
478
 
451
479
  Yields:
452
480
  Results obtained by applying the function to each artifact URI in the
@@ -464,8 +492,8 @@ class RunCollection:
464
492
  path.
465
493
 
466
494
  Args:
467
- func: A function that takes an artifact directory path (string) and
468
- returns a result.
495
+ func (Callable[[str], T]): A function that takes an artifact directory
496
+ path (string) and returns a result.
469
497
 
470
498
  Yields:
471
499
  Results obtained by applying the function to each artifact directory
@@ -473,6 +501,33 @@ class RunCollection:
473
501
  """
474
502
  return (func(download_artifacts(run_id=run.info.run_id)) for run in self._runs)
475
503
 
504
+ def group_by(
505
+ self, names: list[str] | None = None, *args
506
+ ) -> dict[tuple[str, ...], RunCollection]:
507
+ """
508
+ Group the runs by the specified parameter names and return a dictionary
509
+ where the keys are the parameter values and the values are the runs.
510
+
511
+ Args:
512
+ names (list[str] | None): The parameter names to group by.
513
+ *args: Additional positional arguments to specify parameter names.
514
+
515
+ Returns:
516
+ A dictionary where the keys are the parameter values and the values
517
+ are the runs.
518
+ """
519
+ names = names[:] if names else []
520
+ names.extend(args)
521
+
522
+ grouped_runs = {}
523
+ for run in self._runs:
524
+ key = get_params(run, names)
525
+ if key not in grouped_runs:
526
+ grouped_runs[key] = []
527
+ grouped_runs[key].append(run)
528
+
529
+ return {key: RunCollection(runs) for key, runs in grouped_runs.items()}
530
+
476
531
 
477
532
  def _param_matches(run: Run, key: str, value: Any) -> bool:
478
533
  """
@@ -483,9 +538,9 @@ def _param_matches(run: Run, key: str, value: Any) -> bool:
483
538
  and tuples.
484
539
 
485
540
  Args:
486
- run: The run object to check.
487
- key: The parameter key to check.
488
- value: The parameter value to check.
541
+ run (Run): The run object to check.
542
+ key (str): The parameter key to check.
543
+ value (Any): The parameter value to check.
489
544
 
490
545
  Returns:
491
546
  True if the run's parameter matches the specified key-value pair,
@@ -526,10 +581,10 @@ def filter_runs(runs: list[Run], config: object | None = None, **kwargs) -> list
526
581
  exclusive of the upper bound).
527
582
 
528
583
  Args:
529
- runs: The list of runs to filter.
530
- config: The configuration object to filter the runs. This can be any
531
- object that provides key-value pairs through the `iter_params`
532
- function.
584
+ runs (list[Run]): The list of runs to filter.
585
+ config (object | None): The configuration object to filter the runs.
586
+ This can be any object that provides key-value pairs through the
587
+ `iter_params` function.
533
588
  **kwargs: Additional key-value pairs to filter the runs.
534
589
 
535
590
  Returns:
@@ -554,8 +609,8 @@ def find_run(runs: list[Run], config: object | None = None, **kwargs) -> Run:
554
609
  raised.
555
610
 
556
611
  Args:
557
- runs: The runs to filter.
558
- config: The configuration object to identify the run.
612
+ runs (list[Run]): The runs to filter.
613
+ config (object | None): The configuration object to identify the run.
559
614
  **kwargs: Additional key-value pairs to filter the runs.
560
615
 
561
616
  Returns:
@@ -584,8 +639,8 @@ def try_find_run(runs: list[Run], config: object | None = None, **kwargs) -> Run
584
639
  the provided parameters. If no run matches the criteria, None is returned.
585
640
 
586
641
  Args:
587
- runs: The runs to filter.
588
- config: The configuration object to identify the run.
642
+ runs (list[Run]): The runs to filter.
643
+ config (object | None): The configuration object to identify the run.
589
644
  **kwargs: Additional key-value pairs to filter the runs.
590
645
 
591
646
  Returns:
@@ -610,8 +665,8 @@ def find_last_run(runs: list[Run], config: object | None = None, **kwargs) -> Ru
610
665
  is raised.
611
666
 
612
667
  Args:
613
- runs: The runs to filter.
614
- config: The configuration object to identify the run.
668
+ runs (list[Run]): The runs to filter.
669
+ config (object | None): The configuration object to identify the run.
615
670
  **kwargs: Additional key-value pairs to filter the runs.
616
671
 
617
672
  Returns:
@@ -641,8 +696,8 @@ def try_find_last_run(runs: list[Run], config: object | None = None, **kwargs) -
641
696
  the provided parameters. If no run matches the criteria, None is returned.
642
697
 
643
698
  Args:
644
- runs: The runs to filter.
645
- config: The configuration object to identify the run.
699
+ runs (list[Run]): The runs to filter.
700
+ config (object | None): The configuration object to identify the run.
646
701
  **kwargs: Additional key-value pairs to filter the runs.
647
702
 
648
703
  Returns:
@@ -667,8 +722,8 @@ def get_run(runs: list[Run], config: object | None = None, **kwargs) -> Run:
667
722
  than one run matches the criteria, a `ValueError` is raised.
668
723
 
669
724
  Args:
670
- runs: The runs to filter.
671
- config: The configuration object to identify the run.
725
+ runs (list[Run]): The runs to filter.
726
+ config (object | None): The configuration object to identify the run.
672
727
  **kwargs: Additional key-value pairs to filter the runs.
673
728
 
674
729
  Returns:
@@ -707,8 +762,8 @@ def try_get_run(runs: list[Run], config: object | None = None, **kwargs) -> Run
707
762
  If more than one run matches the criteria, a `ValueError` is raised.
708
763
 
709
764
  Args:
710
- runs: The runs to filter.
711
- config: The configuration object to identify the run.
765
+ runs (list[Run]): The runs to filter.
766
+ config (object | None): The configuration object to identify the run.
712
767
  **kwargs: Additional key-value pairs to filter the runs.
713
768
 
714
769
  Returns:
@@ -737,6 +792,13 @@ def try_get_run(runs: list[Run], config: object | None = None, **kwargs) -> Run
737
792
  raise ValueError(msg)
738
793
 
739
794
 
795
+ def get_params(run: Run, names: list[str] | None = None, *args) -> tuple[str, ...]:
796
+ names = names[:] if names else []
797
+ names.extend(args)
798
+
799
+ return tuple(run.data.params[name] for name in names)
800
+
801
+
740
802
  def get_param_names(runs: list[Run]) -> list[str]:
741
803
  """
742
804
  Get the parameter names from the runs.
@@ -746,7 +808,7 @@ def get_param_names(runs: list[Run]) -> list[str]:
746
808
  set to ensure uniqueness.
747
809
 
748
810
  Args:
749
- runs: The list of runs from which to extract parameter names.
811
+ runs (list[Run]): The list of runs from which to extract parameter names.
750
812
 
751
813
  Returns:
752
814
  A list of unique parameter names.
@@ -770,7 +832,8 @@ def get_param_dict(runs: list[Run]) -> dict[str, list[str]]:
770
832
  and the values are lists of parameter values.
771
833
 
772
834
  Args:
773
- runs: The list of runs from which to extract parameter names and values.
835
+ runs (list[Run]): The list of runs from which to extract parameter names
836
+ and values.
774
837
 
775
838
  Returns:
776
839
  A dictionary where the keys are parameter names and the values are lists
@@ -795,7 +858,7 @@ def load_config(run: Run) -> DictConfig:
795
858
  `.hydra/config.yaml` is not found in the run's artifact directory.
796
859
 
797
860
  Args:
798
- run: The Run instance for which to load the configuration.
861
+ run (Run): The Run instance for which to load the configuration.
799
862
 
800
863
  Returns:
801
864
  The loaded configuration as a DictConfig object. Returns an empty
@@ -813,37 +876,3 @@ def _load_config(run_id: str) -> DictConfig:
813
876
  return DictConfig({})
814
877
 
815
878
  return OmegaConf.load(path) # type: ignore
816
-
817
-
818
- # def get_hydra_output_dir(run: Run_ | Series | str) -> Path:
819
- # """
820
- # Get the Hydra output directory.
821
-
822
- # Args:
823
- # run: The run object.
824
-
825
- # Returns:
826
- # Path: The Hydra output directory.
827
- # """
828
- # path = get_artifact_dir(run) / ".hydra/hydra.yaml"
829
-
830
- # if path.exists():
831
- # hc = OmegaConf.load(path)
832
- # return Path(hc.hydra.runtime.output_dir)
833
-
834
- # raise FileNotFoundError
835
-
836
-
837
- # def log_hydra_output_dir(run: Run_ | Series | str) -> None:
838
- # """
839
- # Log the Hydra output directory.
840
-
841
- # Args:
842
- # run: The run object.
843
-
844
- # Returns:
845
- # None
846
- # """
847
- # output_dir = get_hydra_output_dir(run)
848
- # run_id = run if isinstance(run, str) else run.info.run_id
849
- # mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
@@ -7,7 +7,7 @@ import hydra
7
7
  import mlflow
8
8
  from hydra.core.config_store import ConfigStore
9
9
 
10
- from hydraflow.context import log_run
10
+ import hydraflow
11
11
 
12
12
  log = logging.getLogger(__name__)
13
13
 
@@ -25,11 +25,13 @@ cs.store(name="config", node=MySQLConfig)
25
25
  @hydra.main(version_base=None, config_name="config")
26
26
  def app(cfg: MySQLConfig):
27
27
  mlflow.set_experiment("log_run")
28
- with mlflow.start_run(), log_run(cfg) as info:
28
+ with hydraflow.start_run(cfg):
29
+ artifact_dir = hydraflow.get_artifact_dir()
30
+ output_dir = hydraflow.get_hydra_output_dir()
29
31
  log.info(f"START, {cfg.host}, {cfg.port} ")
30
- mlflow.log_text("A " + info.artifact_dir.as_posix(), "artifact_dir.txt")
31
- mlflow.log_text("B " + info.output_dir.as_posix(), "output_dir.txt")
32
- (info.artifact_dir / "a.txt").write_text("abc")
32
+ mlflow.log_text("A " + artifact_dir.as_posix(), "artifact_dir.txt")
33
+ mlflow.log_text("B " + output_dir.as_posix(), "output_dir.txt")
34
+ (artifact_dir / "a.txt").write_text("abc")
33
35
  log.info("END")
34
36
 
35
37
 
@@ -0,0 +1,80 @@
1
+ from unittest.mock import MagicMock, patch
2
+
3
+ import mlflow
4
+ import pytest
5
+
6
+ from hydraflow.context import log_run, start_run, watch
7
+ from hydraflow.runs import RunCollection
8
+
9
+
10
+ @pytest.fixture
11
+ def runs(monkeypatch, tmp_path):
12
+ from hydraflow.runs import list_runs
13
+
14
+ monkeypatch.chdir(tmp_path)
15
+
16
+ with (
17
+ patch("hydraflow.context.HydraConfig.get") as mock_hydra_config,
18
+ patch("hydraflow.context.mlflow.log_artifacts") as mock_log_artifacts,
19
+ ):
20
+ mock_hydra_config.return_value.runtime.output_dir = "/tmp"
21
+ mock_log_artifacts.return_value = None
22
+
23
+ mlflow.set_experiment("test_run")
24
+ for x in range(3):
25
+ cfg = {"x": x, "l": [x, x, x], "d": {"i": x}}
26
+ with start_run(cfg):
27
+ mlflow.log_param("y", x)
28
+
29
+ return list_runs(["test_run"])
30
+
31
+
32
+ def test_runs_len(runs: RunCollection):
33
+ assert len(runs) == 3
34
+
35
+
36
+ @pytest.mark.parametrize("i", [0, 1, 2])
37
+ @pytest.mark.parametrize("n", ["x", "y"])
38
+ def test_runs_params(runs: RunCollection, i: int, n: str):
39
+ assert runs[i].data.params[n] == str(i)
40
+
41
+
42
+ @pytest.mark.parametrize("i", [0, 1, 2])
43
+ def test_runs_params_list(runs: RunCollection, i: int):
44
+ assert runs[i].data.params["l"] == f"[{i}, {i}, {i}]"
45
+
46
+
47
+ @pytest.mark.parametrize("i", [0, 1, 2])
48
+ def test_runs_params_dict(runs: RunCollection, i: int):
49
+ assert runs[i].data.params["d.i"] == str(i)
50
+
51
+
52
+ def test_log_run_error_handling():
53
+ config = MagicMock()
54
+ config.some_param = "value"
55
+
56
+ with (
57
+ patch("hydraflow.context.log_params") as mock_log_params,
58
+ patch("hydraflow.context.HydraConfig.get") as mock_hydra_config,
59
+ patch("hydraflow.context.mlflow.log_artifacts") as mock_log_artifacts,
60
+ ):
61
+ mock_log_params.side_effect = Exception("Test exception")
62
+ mock_hydra_config.return_value.runtime.output_dir = "/tmp"
63
+ mock_log_artifacts.return_value = None
64
+
65
+ with pytest.raises(Exception, match="Test exception"):
66
+ with log_run(config):
67
+ pass
68
+
69
+
70
+ def test_watch_error_handling():
71
+ func = MagicMock()
72
+ dir = "/tmp"
73
+
74
+ with patch("hydraflow.context.Observer") as mock_observer:
75
+ mock_observer_instance = mock_observer.return_value
76
+ mock_observer_instance.start.side_effect = Exception("Test exception")
77
+
78
+ with pytest.raises(Exception, match="Test exception"):
79
+ with watch(func, dir):
80
+ pass
@@ -26,7 +26,7 @@ def runs(monkeypatch, tmp_path):
26
26
 
27
27
  @pytest.fixture(params=range(4))
28
28
  def run(runs, request):
29
- run = runs[request.param]
29
+ run = runs[request.param] # type: ignore
30
30
  assert isinstance(run, Run)
31
31
  return run
32
32
 
@@ -396,6 +396,44 @@ def test_run_collection_map_dir(runs: RunCollection):
396
396
  assert all(isinstance(dir_path, str) for dir_path in results)
397
397
 
398
398
 
399
+ def test_run_collection_sort(runs: RunCollection):
400
+ runs.sort(key=lambda x: x.data.params["p"])
401
+ assert [run.data.params["p"] for run in runs] == ["0", "1", "2", "3", "4", "5"]
402
+
403
+ runs.sort(reverse=True)
404
+ assert [run.data.params["p"] for run in runs] == ["5", "4", "3", "2", "1", "0"]
405
+
406
+
407
+ def test_run_collection_iter(runs: RunCollection):
408
+ assert list(runs) == runs._runs
409
+
410
+
411
+ @pytest.mark.parametrize("i", range(6))
412
+ def test_run_collection_getitem(runs: RunCollection, i: int):
413
+ assert runs[i] == runs._runs[i]
414
+
415
+
416
+ @pytest.mark.parametrize("i", range(6))
417
+ def test_run_collection_contains(runs: RunCollection, i: int):
418
+ assert runs[i] in runs
419
+ assert runs._runs[i] in runs
420
+
421
+
422
+ def test_run_collection_group_by(runs: RunCollection):
423
+ grouped = runs.group_by(["p"])
424
+ assert len(grouped) == 6
425
+ assert all(isinstance(group, RunCollection) for group in grouped.values())
426
+ assert all(len(group) == 1 for group in grouped.values())
427
+ assert grouped[("0",)][0] == runs[0]
428
+ assert grouped[("1",)][0] == runs[1]
429
+
430
+ grouped = runs.group_by(["q"])
431
+ assert len(grouped) == 2
432
+
433
+ grouped = runs.group_by(["r"])
434
+ assert len(grouped) == 3
435
+
436
+
399
437
  # def test_hydra_output_dir_error(runs_list: list[Run]):
400
438
  # from hydraflow.runs import get_hydra_output_dir
401
439
 
@@ -1,30 +0,0 @@
1
- from .context import Info, chdir_artifact, log_run, watch
2
- from .mlflow import set_experiment
3
- from .runs import (
4
- Run,
5
- RunCollection,
6
- filter_runs,
7
- get_param_dict,
8
- get_param_names,
9
- get_run,
10
- list_runs,
11
- load_config,
12
- search_runs,
13
- )
14
-
15
- __all__ = [
16
- "Info",
17
- "Run",
18
- "RunCollection",
19
- "chdir_artifact",
20
- "filter_runs",
21
- "get_param_dict",
22
- "get_param_names",
23
- "get_run",
24
- "list_runs",
25
- "load_config",
26
- "log_run",
27
- "search_runs",
28
- "set_experiment",
29
- "watch",
30
- ]
@@ -1,72 +0,0 @@
1
- """
2
- This module provides functionality to log parameters from Hydra
3
- configuration objects and set up experiments using MLflow.
4
- """
5
-
6
- from __future__ import annotations
7
-
8
- from pathlib import Path
9
-
10
- import mlflow
11
- from hydra.core.hydra_config import HydraConfig
12
-
13
- from hydraflow.config import iter_params
14
-
15
-
16
- def set_experiment(prefix: str = "", suffix: str = "", uri: str | None = None) -> None:
17
- """
18
- Set the experiment name and tracking URI optionally.
19
-
20
- This function sets the experiment name by combining the given prefix,
21
- the job name from HydraConfig, and the given suffix. Optionally, it can
22
- also set the tracking URI.
23
-
24
- Args:
25
- prefix: The prefix to prepend to the experiment name.
26
- suffix: The suffix to append to the experiment name.
27
- uri: The tracking URI to use.
28
- """
29
- if uri:
30
- mlflow.set_tracking_uri(uri)
31
-
32
- hc = HydraConfig.get()
33
- name = f"{prefix}{hc.job.name}{suffix}"
34
- mlflow.set_experiment(name)
35
-
36
-
37
- def log_params(config: object, *, synchronous: bool | None = None) -> None:
38
- """
39
- Log the parameters from the given configuration object.
40
-
41
- This method logs the parameters from the provided configuration object
42
- using MLflow. It iterates over the parameters and logs them using the
43
- `mlflow.log_param` method.
44
-
45
- Args:
46
- config: The configuration object to log the parameters from.
47
- synchronous: Whether to log the parameters synchronously.
48
- Defaults to None.
49
- """
50
- for key, value in iter_params(config):
51
- mlflow.log_param(key, value, synchronous=synchronous)
52
-
53
-
54
- def get_artifact_dir(artifact_path: str | None = None) -> Path:
55
- """
56
- Get the artifact directory for the given artifact path.
57
-
58
- This function retrieves the artifact URI for the specified artifact path
59
- using MLflow, downloads the artifacts to a local directory, and returns
60
- the path to that directory.
61
-
62
- Args:
63
- artifact_path: The artifact path for which to get the directory.
64
- Defaults to None.
65
-
66
- Returns:
67
- The local path to the directory where the artifacts are downloaded.
68
- """
69
- uri = mlflow.get_artifact_uri(artifact_path)
70
- dir = mlflow.artifacts.download_artifacts(artifact_uri=uri)
71
-
72
- return Path(dir)
@@ -1,36 +0,0 @@
1
- from unittest.mock import MagicMock, patch
2
-
3
- import pytest
4
-
5
- from hydraflow.context import log_run, watch
6
-
7
-
8
- def test_log_run_error_handling():
9
- config = MagicMock()
10
- config.some_param = "value"
11
-
12
- with (
13
- patch("hydraflow.context.log_params") as mock_log_params,
14
- patch("hydraflow.context.HydraConfig.get") as mock_hydra_config,
15
- patch("hydraflow.context.mlflow.log_artifacts") as mock_log_artifacts,
16
- ):
17
- mock_log_params.side_effect = Exception("Test exception")
18
- mock_hydra_config.return_value.runtime.output_dir = "/tmp"
19
- mock_log_artifacts.return_value = None
20
-
21
- with pytest.raises(Exception, match="Test exception"):
22
- with log_run(config):
23
- pass
24
-
25
-
26
- def test_watch_error_handling():
27
- func = MagicMock()
28
- dir = "/tmp"
29
-
30
- with patch("hydraflow.context.Observer") as mock_observer:
31
- mock_observer_instance = mock_observer.return_value
32
- mock_observer_instance.start.side_effect = Exception("Test exception")
33
-
34
- with pytest.raises(Exception, match="Test exception"):
35
- with watch(func, dir):
36
- pass
File without changes
File without changes
File without changes
File without changes