hydraflow 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
hydraflow/__init__.py CHANGED
@@ -1,11 +1,11 @@
1
1
  from .context import chdir_artifact, log_run, start_run, watch
2
- from .mlflow import get_artifact_dir, get_hydra_output_dir, set_experiment
3
- from .runs import (
4
- RunCollection,
2
+ from .info import get_artifact_dir, get_hydra_output_dir, load_config
3
+ from .mlflow import (
5
4
  list_runs,
6
- load_config,
7
5
  search_runs,
6
+ set_experiment,
8
7
  )
8
+ from .run_collection import RunCollection
9
9
 
10
10
  __all__ = [
11
11
  "RunCollection",
hydraflow/context.py CHANGED
@@ -14,10 +14,11 @@ from typing import TYPE_CHECKING
14
14
 
15
15
  import mlflow
16
16
  from hydra.core.hydra_config import HydraConfig
17
- from watchdog.events import FileModifiedEvent, FileSystemEventHandler
17
+ from watchdog.events import FileModifiedEvent, PatternMatchingEventHandler
18
18
  from watchdog.observers import Observer
19
19
 
20
- from hydraflow.mlflow import get_artifact_dir, log_params
20
+ from hydraflow.info import get_artifact_dir
21
+ from hydraflow.mlflow import log_params
21
22
 
22
23
  if TYPE_CHECKING:
23
24
  from collections.abc import Callable, Iterator
@@ -68,7 +69,7 @@ def log_run(
68
69
  mlflow.log_artifact(local_path)
69
70
 
70
71
  try:
71
- with watch(log_artifact, output_dir):
72
+ with watch(log_artifact, output_dir, ignore_log=False):
72
73
  yield
73
74
 
74
75
  except Exception as e:
@@ -140,9 +141,11 @@ def start_run(
140
141
 
141
142
  @contextmanager
142
143
  def watch(
143
- func: Callable[[Path], None],
144
+ callback: Callable[[Path], None],
144
145
  dir: Path | str = "",
145
146
  timeout: int = 60,
147
+ ignore_patterns: list[str] | None = None,
148
+ ignore_log: bool = True,
146
149
  ) -> Iterator[None]:
147
150
  """
148
151
  Watch the given directory for changes and call the provided function
@@ -154,7 +157,7 @@ def watch(
154
157
  period or until the context is exited.
155
158
 
156
159
  Args:
157
- func (Callable[[Path], None]): The function to call when a change is
160
+ callback (Callable[[Path], None]): The function to call when a change is
158
161
  detected. It should accept a single argument of type `Path`,
159
162
  which is the path of the modified file.
160
163
  dir (Path | str): The directory to watch. If not specified,
@@ -174,7 +177,7 @@ def watch(
174
177
  if isinstance(dir, Path):
175
178
  dir = dir.as_posix()
176
179
 
177
- handler = Handler(func)
180
+ handler = Handler(callback, ignore_patterns=ignore_patterns, ignore_log=ignore_log)
178
181
  observer = Observer()
179
182
  observer.schedule(handler, dir, recursive=True)
180
183
  observer.start()
@@ -198,10 +201,23 @@ def watch(
198
201
  observer.join()
199
202
 
200
203
 
201
- class Handler(FileSystemEventHandler):
202
- def __init__(self, func: Callable[[Path], None]) -> None:
204
+ class Handler(PatternMatchingEventHandler):
205
+ def __init__(
206
+ self,
207
+ func: Callable[[Path], None],
208
+ ignore_patterns: list[str] | None = None,
209
+ ignore_log: bool = True,
210
+ ) -> None:
203
211
  self.func = func
204
212
 
213
+ if ignore_log:
214
+ if ignore_patterns:
215
+ ignore_patterns.append("*.log")
216
+ else:
217
+ ignore_patterns = ["*.log"]
218
+
219
+ super().__init__(ignore_patterns=ignore_patterns)
220
+
205
221
  def on_modified(self, event: FileModifiedEvent) -> None:
206
222
  file = Path(str(event.src_path))
207
223
  if file.is_file():
hydraflow/info.py ADDED
@@ -0,0 +1,116 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING
5
+
6
+ import mlflow
7
+ from hydra.core.hydra_config import HydraConfig
8
+ from mlflow.tracking import artifact_utils
9
+ from omegaconf import DictConfig, OmegaConf
10
+
11
+ if TYPE_CHECKING:
12
+ from mlflow.entities import Run
13
+
14
+ from hydraflow.run_collection import RunCollection
15
+
16
+
17
+ class RunCollectionInfo:
18
+ def __init__(self, runs: RunCollection):
19
+ self._runs = runs
20
+
21
+ @property
22
+ def run_id(self) -> list[str]:
23
+ return [run.info.run_id for run in self._runs]
24
+
25
+ @property
26
+ def params(self) -> list[dict[str, str]]:
27
+ return [run.data.params for run in self._runs]
28
+
29
+ @property
30
+ def metrics(self) -> list[dict[str, float]]:
31
+ return [run.data.metrics for run in self._runs]
32
+
33
+ @property
34
+ def artifact_uri(self) -> list[str | None]:
35
+ return [run.info.artifact_uri for run in self._runs]
36
+
37
+ @property
38
+ def artifact_dir(self) -> list[Path]:
39
+ return [get_artifact_dir(run) for run in self._runs]
40
+
41
+ @property
42
+ def config(self) -> list[DictConfig]:
43
+ return [load_config(run) for run in self._runs]
44
+
45
+
46
+ def get_artifact_dir(run: Run | None = None) -> Path:
47
+ """
48
+ Retrieve the artifact directory for the given run.
49
+
50
+ This function uses MLflow to get the artifact directory for the given run.
51
+
52
+ Args:
53
+ run (Run | None): The run object. Defaults to None.
54
+
55
+ Returns:
56
+ The local path to the directory where the artifacts are downloaded.
57
+ """
58
+ if run is None:
59
+ uri = mlflow.get_artifact_uri()
60
+ else:
61
+ uri = artifact_utils.get_artifact_uri(run.info.run_id)
62
+
63
+ return Path(mlflow.artifacts.download_artifacts(uri))
64
+
65
+
66
+ def get_hydra_output_dir(run: Run | None = None) -> Path:
67
+ """
68
+ Retrieve the Hydra output directory for the given run.
69
+
70
+ This function returns the Hydra output directory. If no run is provided,
71
+ it retrieves the output directory from the current Hydra configuration.
72
+ If a run is provided, it retrieves the artifact path for the run, loads
73
+ the Hydra configuration from the downloaded artifacts, and returns the
74
+ output directory specified in that configuration.
75
+
76
+ Args:
77
+ run (Run | None): The run object. Defaults to None.
78
+
79
+ Returns:
80
+ Path: The path to the Hydra output directory.
81
+
82
+ Raises:
83
+ FileNotFoundError: If the Hydra configuration file is not found
84
+ in the artifacts.
85
+ """
86
+ if run is None:
87
+ hc = HydraConfig.get()
88
+ return Path(hc.runtime.output_dir)
89
+
90
+ path = get_artifact_dir(run) / ".hydra/hydra.yaml"
91
+
92
+ if path.exists():
93
+ hc = OmegaConf.load(path)
94
+ return Path(hc.hydra.runtime.output_dir)
95
+
96
+ raise FileNotFoundError
97
+
98
+
99
+ def load_config(run: Run) -> DictConfig:
100
+ """
101
+ Load the configuration for a given run.
102
+
103
+ This function loads the configuration for the provided Run instance
104
+ by downloading the configuration file from the MLflow artifacts and
105
+ loading it using OmegaConf. The file is read from
106
+ `.hydra/config.yaml` in the run's artifact directory.
107
+
108
+ Args:
109
+ run (Run): The Run instance for which to load the configuration.
110
+
111
+ Returns:
112
+ The loaded configuration as a DictConfig object. No fallback is
113
+ applied here if the configuration file is not found.
114
+ """
115
+ path = get_artifact_dir(run) / ".hydra/config.yaml"
116
+ return OmegaConf.load(path) # type: ignore
hydraflow/mlflow.py CHANGED
@@ -1,6 +1,20 @@
1
1
  """
2
- This module provides functionality to log parameters from Hydra
3
- configuration objects and set up experiments using MLflow.
2
+ This module provides functionality to log parameters from Hydra configuration objects
3
+ and set up experiments using MLflow. It includes methods for managing experiments,
4
+ searching for runs, and logging parameters and artifacts.
5
+
6
+ Key Features:
7
+ - **Experiment Management**: Set and manage MLflow experiments with customizable names
8
+ based on Hydra configuration.
9
+ - **Run Logging**: Log parameters and metrics from Hydra configuration objects to
10
+ MLflow, ensuring that all relevant information is captured during experiments.
11
+ - **Run Search**: Search for runs based on various criteria, allowing for flexible
12
+ retrieval of experiment results.
13
+ - **Artifact Management**: Retrieve and log artifacts associated with runs, facilitating
14
+ easy access to outputs generated during experiments.
15
+
16
+ This module is designed to integrate seamlessly with Hydra, providing a robust
17
+ solution for tracking machine learning experiments and their associated metadata.
4
18
  """
5
19
 
6
20
  from __future__ import annotations
@@ -10,10 +24,11 @@ from typing import TYPE_CHECKING
10
24
 
11
25
  import mlflow
12
26
  from hydra.core.hydra_config import HydraConfig
13
- from mlflow.tracking import artifact_utils
14
- from omegaconf import OmegaConf
27
+ from mlflow.entities import ViewType
28
+ from mlflow.tracking.fluent import SEARCH_MAX_RESULTS_PANDAS
15
29
 
16
30
  from hydraflow.config import iter_params
31
+ from hydraflow.run_collection import RunCollection
17
32
 
18
33
  if TYPE_CHECKING:
19
34
  from mlflow.entities.experiment import Experiment
@@ -25,7 +40,7 @@ def set_experiment(
25
40
  uri: str | Path | None = None,
26
41
  ) -> Experiment:
27
42
  """
28
- Set the experiment name and tracking URI optionally.
43
+ Sets the experiment name and tracking URI optionally.
29
44
 
30
45
  This function sets the experiment name by combining the given prefix,
31
46
  the job name from HydraConfig, and the given suffix. Optionally, it can
@@ -65,60 +80,96 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
65
80
  mlflow.log_param(key, value, synchronous=synchronous)
66
81
 
67
82
 
68
- def get_artifact_dir(
69
- artifact_path: str | None = None,
70
- *,
71
- run_id: str | None = None,
72
- ) -> Path:
83
+ def search_runs(
84
+ experiment_ids: list[str] | None = None,
85
+ filter_string: str = "",
86
+ run_view_type: int = ViewType.ACTIVE_ONLY,
87
+ max_results: int = SEARCH_MAX_RESULTS_PANDAS,
88
+ order_by: list[str] | None = None,
89
+ search_all_experiments: bool = False,
90
+ experiment_names: list[str] | None = None,
91
+ ) -> RunCollection:
73
92
  """
74
- Get the artifact directory for the given artifact path.
93
+ Search for Runs that fit the specified criteria.
75
94
 
76
- This function retrieves the artifact URI for the specified artifact path
77
- using MLflow, downloads the artifacts to a local directory, and returns
78
- the path to that directory.
95
+ This function wraps the `mlflow.search_runs` function and returns the
96
+ results as a `RunCollection` object. It allows for flexible searching of
97
+ MLflow runs based on various criteria.
98
+
99
+ Note:
100
+ The returned runs are sorted by their start time in ascending order.
79
101
 
80
102
  Args:
81
- artifact_path (str | None): The artifact path for which to get the
82
- directory. Defaults to None.
83
- run_id (str | None): The run ID for which to get the artifact directory.
103
+ experiment_ids (list[str] | None): List of experiment IDs. Search can
104
+ work with experiment IDs or experiment names, but not both in the
105
+ same call. Values other than ``None`` or ``[]`` will result in
106
+ error if ``experiment_names`` is also not ``None`` or ``[]``.
107
+ ``None`` will default to the active experiment if ``experiment_names``
108
+ is ``None`` or ``[]``.
109
+ filter_string (str): Filter query string, defaults to searching all
110
+ runs.
111
+ run_view_type (int): one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``,
112
+ or ``ALL`` runs defined in :py:class:`mlflow.entities.ViewType`.
113
+ max_results (int): The maximum number of runs to put in the dataframe.
114
+ Default is 100,000 to avoid causing out-of-memory issues on the user's
115
+ machine.
116
+ order_by (list[str] | None): List of columns to order by (e.g.,
117
+ "metrics.rmse"). The ``order_by`` column can contain an optional
118
+ ``DESC`` or ``ASC`` value. The default is ``ASC``. The default
119
+ ordering is to sort by ``start_time DESC``,
120
+ then ``run_id``.
121
+ search_all_experiments (bool): Boolean specifying whether all
122
+ experiments should be searched. Only honored if ``experiment_ids``
123
+ is ``[]`` or ``None``.
124
+ experiment_names (list[str] | None): List of experiment names. Search
125
+ can work with experiment IDs or experiment names, but not both in
126
+ the same call. Values other than ``None`` or ``[]`` will result in
127
+ error if ``experiment_ids`` is also not ``None`` or ``[]``.
128
+ ``None`` will default to the active experiment if
129
+ ``experiment_ids`` is ``None``
130
+ or ``[]``.
84
131
 
85
132
  Returns:
86
- The local path to the directory where the artifacts are downloaded.
133
+ A `RunCollection` object containing the search results.
87
134
  """
88
- if run_id is None:
89
- uri = mlflow.get_artifact_uri(artifact_path)
90
- else:
91
- uri = artifact_utils.get_artifact_uri(run_id, artifact_path)
92
-
93
- dir = mlflow.artifacts.download_artifacts(artifact_uri=uri)
94
-
95
- return Path(dir)
96
-
97
-
98
- def get_hydra_output_dir(*, run_id: str | None = None) -> Path:
99
- if run_id is None:
100
- hc = HydraConfig.get()
101
- return Path(hc.runtime.output_dir)
102
-
103
- path = get_artifact_dir(run_id=run_id) / ".hydra/hydra.yaml"
104
-
105
- if path.exists():
106
- hc = OmegaConf.load(path)
107
- return Path(hc.hydra.runtime.output_dir)
135
+ runs = mlflow.search_runs(
136
+ experiment_ids=experiment_ids,
137
+ filter_string=filter_string,
138
+ run_view_type=run_view_type,
139
+ max_results=max_results,
140
+ order_by=order_by,
141
+ output_format="list",
142
+ search_all_experiments=search_all_experiments,
143
+ experiment_names=experiment_names,
144
+ )
145
+ runs = sorted(runs, key=lambda run: run.info.start_time) # type: ignore
146
+ return RunCollection(runs) # type: ignore
147
+
148
+
149
+ def list_runs(experiment_names: list[str] | None = None) -> RunCollection:
150
+ """
151
+ List all runs for the specified experiments.
108
152
 
109
- raise FileNotFoundError
153
+ This function retrieves all runs for the given list of experiment names.
154
+ If no experiment names are provided (None), it defaults to searching all runs
155
+ for the currently active experiment. If an empty list is provided, the function
156
+ will search all runs for all experiments except the "Default" experiment.
157
+ The function returns the results as a `RunCollection` object.
110
158
 
159
+ Note:
160
+ The returned runs are sorted by their start time in ascending order.
111
161
 
112
- # def log_hydra_output_dir(run: Run_ | Series | str) -> None:
113
- # """
114
- # Log the Hydra output directory.
162
+ Args:
163
+ experiment_names (list[str] | None): List of experiment names to search
164
+ for runs. If None or an empty list is provided, the function will
165
+ search the currently active experiment or all experiments except
166
+ the "Default" experiment.
115
167
 
116
- # Args:
117
- # run: The run object.
168
+ Returns:
169
+ A `RunCollection` object containing the runs for the specified experiments.
170
+ """
171
+ if experiment_names == []:
172
+ experiments = mlflow.search_experiments()
173
+ experiment_names = [e.name for e in experiments if e.name != "Default"]
118
174
 
119
- # Returns:
120
- # None
121
- # """
122
- # output_dir = get_hydra_output_dir(run)
123
- # run_id = run if isinstance(run, str) else run.info.run_id
124
- # mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
175
+ return search_runs(experiment_names=experiment_names)
hydraflow/progress.py CHANGED
@@ -3,27 +3,57 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING
4
4
 
5
5
  import joblib
6
- from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
6
+ from rich.progress import Progress
7
7
 
8
8
  if TYPE_CHECKING:
9
9
  from collections.abc import Iterable
10
10
 
11
+ from rich.progress import ProgressColumn
11
12
 
12
- def progress(
13
- *iterables: Iterable[int | tuple[int, int]],
13
+
14
+ def multi_task_progress(
15
+ iterables: Iterable[Iterable[int | tuple[int, int]]],
16
+ *columns: ProgressColumn | str,
14
17
  n_jobs: int = -1,
15
- task_name: str = "#{:0>3}",
16
- main_task_name: str = "main",
18
+ description: str = "#{:0>3}",
19
+ main_description: str = "main",
20
+ transient: bool | None = None,
21
+ **kwargs,
17
22
  ) -> None:
18
- with Progress(
19
- SpinnerColumn(),
20
- *Progress.get_default_columns(),
21
- TimeElapsedColumn(),
22
- ) as progress:
23
+ """
24
+ Render auto-updating progress bars for multiple tasks concurrently.
25
+
26
+ Args:
27
+ iterables (Iterable[Iterable[int | tuple[int, int]]]): A collection of
28
+ iterables, each representing a task. Each iterable can yield
29
+ integers (completed) or tuples of integers (completed, total).
30
+ *columns (ProgressColumn | str): Additional columns to display in the
31
+ progress bars.
32
+ n_jobs (int, optional): Number of jobs to run in parallel. Defaults to
33
+ -1, which means using all processors.
34
+ description (str, optional): Format string for describing tasks. Defaults to
35
+ "#{:0>3}".
36
+ main_description (str, optional): Description for the main task.
37
+ Defaults to "main".
38
+ transient (bool | None, optional): Whether to remove the progress bar
39
+ after completion. Defaults to None.
40
+ **kwargs: Additional keyword arguments passed to the Progress instance.
41
+
42
+ Returns:
43
+ None
44
+ """
45
+ if not columns:
46
+ columns = Progress.get_default_columns()
47
+
48
+ iterables = list(iterables)
49
+
50
+ with Progress(*columns, transient=transient or False, **kwargs) as progress:
23
51
  n = len(iterables)
24
52
 
25
- task_main = progress.add_task(main_task_name, total=None) if n > 1 else None
26
- tasks = [progress.add_task(task_name.format(i), start=False, total=None) for i in range(n)]
53
+ task_main = progress.add_task(main_description, total=None) if n > 1 else None
54
+ tasks = [
55
+ progress.add_task(description.format(i), start=False, total=None) for i in range(n)
56
+ ]
27
57
 
28
58
  total = {}
29
59
  completed = {}
@@ -48,9 +78,54 @@ def progress(
48
78
  c = sum(completed.values())
49
79
  progress.update(task_main, total=t, completed=c)
50
80
 
81
+ if transient or n > 1:
82
+ progress.remove_task(tasks[i])
83
+
51
84
  if n > 1:
52
85
  it = (joblib.delayed(func)(i) for i in range(n))
53
86
  joblib.Parallel(n_jobs, prefer="threads")(it)
54
87
 
55
88
  else:
56
89
  func(0)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ import random
94
+ import time
95
+
96
+ from rich.progress import MofNCompleteColumn, Progress, SpinnerColumn, TimeElapsedColumn
97
+
98
+ from hydraflow.progress import multi_task_progress
99
+
100
+ def task(total):
101
+ for i in range(total or 90):
102
+ if total is None:
103
+ yield i
104
+ else:
105
+ yield i, total
106
+ time.sleep(random.random() / 30)
107
+
108
+ def multi_task_progress_test(unknown_total: bool):
109
+ tasks = [task(random.randint(80, 100)) for _ in range(4)]
110
+ if unknown_total:
111
+ tasks = [task(None), *tasks, task(None)]
112
+
113
+ columns = [
114
+ SpinnerColumn(),
115
+ *Progress.get_default_columns(),
116
+ MofNCompleteColumn(),
117
+ TimeElapsedColumn(),
118
+ ]
119
+
120
+ kwargs = {}
121
+ if unknown_total:
122
+ kwargs["main_description"] = "unknown"
123
+
124
+ multi_task_progress(tasks, *columns, n_jobs=4, **kwargs)
125
+
126
+ multi_task_progress_test(False)
127
+ multi_task_progress_test(True)
128
+ multi_task_progress([task(100)])
129
+ multi_task_progress([task(None)], description="unknown")
130
+ multi_task_progress([task(100), task(None)], main_description="transient", transient=True)
131
+ multi_task_progress([task(100)], description="transient", transient=True)
@@ -1,126 +1,47 @@
1
1
  """
2
- This module provides functionality for managing and interacting with MLflow
3
- runs. It includes the `RunCollection` class and various methods to filter
4
- runs, retrieve run information, log artifacts, and load configurations.
2
+ This module provides functionality for managing and interacting with MLflow runs.
3
+ It includes the `RunCollection` class, which serves as a container for multiple MLflow
4
+ run objects, and various methods to filter, retrieve, and manipulate these runs.
5
+
6
+ Key Features:
7
+ - **Run Management**: The `RunCollection` class allows for easy management of multiple
8
+ MLflow runs, providing methods to access, filter, and sort runs based on various
9
+ criteria.
10
+ - **Filtering**: The module supports filtering runs based on specific configurations
11
+ and parameters, enabling users to easily find runs that match certain conditions.
12
+ - **Retrieval**: Users can retrieve specific runs, including the first, last, or any
13
+ run that matches a given configuration.
14
+ - **Artifact Handling**: The module provides methods to access and manipulate the
15
+ artifacts associated with each run, including retrieving artifact URIs and directories.
16
+
17
+ The `RunCollection` class is designed to work seamlessly with the MLflow tracking
18
+ API, providing a robust solution for managing machine learning experiment runs and
19
+ their associated metadata. This module is particularly useful for data scientists and
20
+ machine learning engineers who need to track and analyze the results of their experiments
21
+ efficiently.
5
22
  """
6
23
 
7
24
  from __future__ import annotations
8
25
 
9
- from dataclasses import dataclass
10
- from functools import cache
26
+ from dataclasses import dataclass, field
11
27
  from itertools import chain
12
- from typing import TYPE_CHECKING, Any, TypeVar
28
+ from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar
13
29
 
14
- import mlflow
15
- from mlflow.artifacts import download_artifacts
16
- from mlflow.entities import ViewType
17
30
  from mlflow.entities.run import Run
18
- from mlflow.tracking.fluent import SEARCH_MAX_RESULTS_PANDAS
19
- from omegaconf import DictConfig, OmegaConf
20
31
 
21
32
  from hydraflow.config import iter_params
33
+ from hydraflow.info import RunCollectionInfo
22
34
 
23
35
  if TYPE_CHECKING:
24
36
  from collections.abc import Callable, Iterator
37
+ from pathlib import Path
25
38
  from typing import Any
26
39
 
27
-
28
- def search_runs(
29
- experiment_ids: list[str] | None = None,
30
- filter_string: str = "",
31
- run_view_type: int = ViewType.ACTIVE_ONLY,
32
- max_results: int = SEARCH_MAX_RESULTS_PANDAS,
33
- order_by: list[str] | None = None,
34
- search_all_experiments: bool = False,
35
- experiment_names: list[str] | None = None,
36
- ) -> RunCollection:
37
- """
38
- Search for Runs that fit the specified criteria.
39
-
40
- This function wraps the `mlflow.search_runs` function and returns the
41
- results as a `RunCollection` object. It allows for flexible searching of
42
- MLflow runs based on various criteria.
43
-
44
- Note:
45
- The returned runs are sorted by their start time in ascending order.
46
-
47
- Args:
48
- experiment_ids (list[str] | None): List of experiment IDs. Search can
49
- work with experiment IDs or experiment names, but not both in the
50
- same call. Values other than ``None`` or ``[]`` will result in
51
- error if ``experiment_names`` is also not ``None`` or ``[]``.
52
- ``None`` will default to the active experiment if ``experiment_names``
53
- is ``None`` or ``[]``.
54
- filter_string (str): Filter query string, defaults to searching all
55
- runs.
56
- run_view_type (int): one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``,
57
- or ``ALL`` runs defined in :py:class:`mlflow.entities.ViewType`.
58
- max_results (int): The maximum number of runs to put in the dataframe.
59
- Default is 100,000 to avoid causing out-of-memory issues on the user's
60
- machine.
61
- order_by (list[str] | None): List of columns to order by (e.g.,
62
- "metrics.rmse"). The ``order_by`` column can contain an optional
63
- ``DESC`` or ``ASC`` value. The default is ``ASC``. The default
64
- ordering is to sort by ``start_time DESC``, then ``run_id``.
65
- ``start_time DESC``, then ``run_id``.
66
- search_all_experiments (bool): Boolean specifying whether all
67
- experiments should be searched. Only honored if ``experiment_ids``
68
- is ``[]`` or ``None``.
69
- experiment_names (list[str] | None): List of experiment names. Search
70
- can work with experiment IDs or experiment names, but not both in
71
- the same call. Values other than ``None`` or ``[]`` will result in
72
- error if ``experiment_ids`` is also not ``None`` or ``[]``.
73
- ``experiment_ids`` is also not ``None`` or ``[]``. ``None`` will
74
- default to the active experiment if ``experiment_ids`` is ``None``
75
- or ``[]``.
76
-
77
- Returns:
78
- A `RunCollection` object containing the search results.
79
- """
80
- runs = mlflow.search_runs(
81
- experiment_ids=experiment_ids,
82
- filter_string=filter_string,
83
- run_view_type=run_view_type,
84
- max_results=max_results,
85
- order_by=order_by,
86
- output_format="list",
87
- search_all_experiments=search_all_experiments,
88
- experiment_names=experiment_names,
89
- )
90
- runs = sorted(runs, key=lambda run: run.info.start_time) # type: ignore
91
- return RunCollection(runs) # type: ignore
92
-
93
-
94
- def list_runs(experiment_names: list[str] | None = None) -> RunCollection:
95
- """
96
- List all runs for the specified experiments.
97
-
98
- This function retrieves all runs for the given list of experiment names.
99
- If no experiment names are provided (None), it defaults to searching all runs
100
- for the currently active experiment. If an empty list is provided, the function
101
- will search all runs for all experiments except the "Default" experiment.
102
- The function returns the results as a `RunCollection` object.
103
-
104
- Note:
105
- The returned runs are sorted by their start time in ascending order.
106
-
107
- Args:
108
- experiment_names (list[str] | None): List of experiment names to search
109
- for runs. If None or an empty list is provided, the function will
110
- search the currently active experiment or all experiments except
111
- the "Default" experiment.
112
-
113
- Returns:
114
- A `RunCollection` object containing the runs for the specified experiments.
115
- """
116
- if experiment_names == []:
117
- experiments = mlflow.search_experiments()
118
- experiment_names = [e.name for e in experiments if e.name != "Default"]
119
-
120
- return search_runs(experiment_names=experiment_names)
40
+ from omegaconf import DictConfig
121
41
 
122
42
 
123
43
  T = TypeVar("T")
44
+ P = ParamSpec("P")
124
45
 
125
46
 
126
47
  @dataclass
@@ -130,11 +51,22 @@ class RunCollection:
130
51
 
131
52
  This class provides methods to interact with the runs, such as filtering,
132
53
  retrieving specific runs, and accessing run information.
54
+
55
+ Key Features:
56
+ - Filtering: Easily filter runs based on various criteria.
57
+ - Retrieval: Access specific runs by index or through methods.
58
+ - Metadata: Access run metadata and associated information.
133
59
  """
134
60
 
135
61
  _runs: list[Run]
136
62
  """A list of MLflow Run objects."""
137
63
 
64
+ _info: RunCollectionInfo = field(init=False)
65
+ """Accessor for per-run information (run IDs, params, metrics, artifacts, configs)."""
66
+
67
+ def __post_init__(self):
68
+ self._info = RunCollectionInfo(self)
69
+
138
70
  def __repr__(self) -> str:
139
71
  return f"{self.__class__.__name__}({len(self)})"
140
72
 
@@ -150,6 +82,10 @@ class RunCollection:
150
82
  def __contains__(self, run: Run) -> bool:
151
83
  return run in self._runs
152
84
 
85
+ @property
86
+ def info(self) -> RunCollectionInfo:
87
+ return self._info
88
+
153
89
  def sort(
154
90
  self,
155
91
  key: Callable[[Run], Any] | None = None,
@@ -411,52 +347,81 @@ class RunCollection:
411
347
  """
412
348
  return get_param_dict(self._runs)
413
349
 
414
- def map(self, func: Callable[[Run], T]) -> Iterator[T]:
350
+ def map(
351
+ self,
352
+ func: Callable[Concatenate[Run, P], T],
353
+ *args: P.args,
354
+ **kwargs: P.kwargs,
355
+ ) -> Iterator[T]:
415
356
  """
416
357
  Apply a function to each run in the collection and return an iterator of
417
358
  results.
418
359
 
360
+ This method iterates over each run in the collection and applies the
361
+ provided function to it, along with any additional arguments and
362
+ keyword arguments.
363
+
419
364
  Args:
420
- func (Callable[[Run], T]): A function that takes a run and returns a
421
- result.
365
+ func (Callable[Concatenate[Run, P], T]): A function that takes a run and
366
+ additional arguments and returns a result.
367
+ *args: Additional arguments to pass to the function.
368
+ **kwargs: Additional keyword arguments to pass to the function.
422
369
 
423
370
  Yields:
424
- Results obtained by applying the function to each run in the
425
- collection.
371
+ Results obtained by applying the function to each run in the collection.
426
372
  """
427
- return (func(run) for run in self._runs)
373
+ return (func(run, *args, **kwargs) for run in self)
428
374
 
429
- def map_run_id(self, func: Callable[[str], T]) -> Iterator[T]:
375
+ def map_run_id(
376
+ self,
377
+ func: Callable[Concatenate[str, P], T],
378
+ *args: P.args,
379
+ **kwargs: P.kwargs,
380
+ ) -> Iterator[T]:
430
381
  """
431
382
  Apply a function to each run id in the collection and return an iterator
432
383
  of results.
433
384
 
434
385
  Args:
435
- func (Callable[[str], T]): A function that takes a run id and returns a
386
+ func (Callable[Concatenate[str, P], T]): A function that takes a run id and returns a
436
387
  result.
388
+ *args: Additional arguments to pass to the function.
389
+ **kwargs: Additional keyword arguments to pass to the function.
437
390
 
438
391
  Yields:
439
392
  Results obtained by applying the function to each run id in the
440
393
  collection.
441
394
  """
442
- return (func(run.info.run_id) for run in self._runs)
395
+ return (func(run_id, *args, **kwargs) for run_id in self.info.run_id)
443
396
 
444
- def map_config(self, func: Callable[[DictConfig], T]) -> Iterator[T]:
397
+ def map_config(
398
+ self,
399
+ func: Callable[Concatenate[DictConfig, P], T],
400
+ *args: P.args,
401
+ **kwargs: P.kwargs,
402
+ ) -> Iterator[T]:
445
403
  """
446
404
  Apply a function to each run configuration in the collection and return
447
405
  an iterator of results.
448
406
 
449
407
  Args:
450
- func (Callable[[DictConfig], T]): A function that takes a run
408
+ func (Callable[Concatenate[DictConfig, P], T]): A function that takes a run
451
409
  configuration and returns a result.
410
+ *args: Additional arguments to pass to the function.
411
+ **kwargs: Additional keyword arguments to pass to the function.
452
412
 
453
413
  Yields:
454
414
  Results obtained by applying the function to each run configuration
455
415
  in the collection.
456
416
  """
457
- return (func(load_config(run)) for run in self._runs)
417
+ return (func(config, *args, **kwargs) for config in self.info.config)
458
418
 
459
- def map_uri(self, func: Callable[[str | None], T]) -> Iterator[T]:
419
+ def map_uri(
420
+ self,
421
+ func: Callable[Concatenate[str | None, P], T],
422
+ *args: P.args,
423
+ **kwargs: P.kwargs,
424
+ ) -> Iterator[T]:
460
425
  """
461
426
  Apply a function to each artifact URI in the collection and return an
462
427
  iterator of results.
@@ -466,16 +431,23 @@ class RunCollection:
466
431
  have an artifact URI, None is passed to the function.
467
432
 
468
433
  Args:
469
- func (Callable[[str | None], T]): A function that takes an
470
- artifact URI (string or None) and returns a result.
434
+ func (Callable[Concatenate[str | None, P], T]): A function that takes an
435
+ artifact URI (string or None) and returns a result.
436
+ *args: Additional arguments to pass to the function.
437
+ **kwargs: Additional keyword arguments to pass to the function.
471
438
 
472
439
  Yields:
473
440
  Results obtained by applying the function to each artifact URI in the
474
441
  collection.
475
442
  """
476
- return (func(run.info.artifact_uri) for run in self._runs)
443
+ return (func(uri, *args, **kwargs) for uri in self.info.artifact_uri)
477
444
 
478
- def map_dir(self, func: Callable[[str], T]) -> Iterator[T]:
445
+ def map_dir(
446
+ self,
447
+ func: Callable[Concatenate[Path, P], T],
448
+ *args: P.args,
449
+ **kwargs: P.kwargs,
450
+ ) -> Iterator[T]:
479
451
  """
480
452
  Apply a function to each artifact directory in the collection and return
481
453
  an iterator of results.
@@ -485,14 +457,16 @@ class RunCollection:
485
457
  path.
486
458
 
487
459
  Args:
488
- func (Callable[[str], T]): A function that takes an artifact directory
460
+ func (Callable[[Path, P], T]): A function that takes an artifact directory
489
461
  path (Path) and returns a result.
462
+ *args: Additional arguments to pass to the function.
463
+ **kwargs: Additional keyword arguments to pass to the function.
490
464
 
491
465
  Yields:
492
466
  Results obtained by applying the function to each artifact directory
493
467
  in the collection.
494
468
  """
495
- return (func(download_artifacts(run_id=run.info.run_id)) for run in self._runs)
469
+ return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir)
496
470
 
497
471
  def group_by(self, *names: str | list[str]) -> dict[tuple[str | None, ...], RunCollection]:
498
472
  """
@@ -519,6 +493,25 @@ class RunCollection:
519
493
 
520
494
  return {key: RunCollection(runs) for key, runs in grouped_runs.items()}
521
495
 
496
+ def group_by_values(self, *names: str | list[str]) -> list[RunCollection]:
497
+ """
498
+ Group runs by specified parameter names.
499
+
500
+ This method groups the runs in the collection based on the values of the
501
+ specified parameters. Each unique combination of parameter values will
502
+ form a separate RunCollection in the returned list.
503
+
504
+ Args:
505
+ *names (str | list[str]): The names of the parameters to group by.
506
+ This can be a single parameter name or multiple names provided
507
+ as separate arguments or as a list.
508
+
509
+ Returns:
510
+ list[RunCollection]: A list of RunCollection objects, where each
511
+ object contains runs that match the specified parameter values.
512
+ """
513
+ return list(self.group_by(*names).values())
514
+
522
515
 
523
516
  def _param_matches(run: Run, key: str, value: Any) -> bool:
524
517
  """
@@ -858,33 +851,3 @@ def get_param_dict(runs: list[Run]) -> dict[str, list[str]]:
858
851
  params[name] = sorted(set(it))
859
852
 
860
853
  return params
861
-
862
-
863
- def load_config(run: Run) -> DictConfig:
864
- """
865
- Load the configuration for a given run.
866
-
867
- This function loads the configuration for the provided Run instance
868
- by downloading the configuration file from the MLflow artifacts and
869
- loading it using OmegaConf. It returns an empty config if
870
- `.hydra/config.yaml` is not found in the run's artifact directory.
871
-
872
- Args:
873
- run (Run): The Run instance for which to load the configuration.
874
-
875
- Returns:
876
- The loaded configuration as a DictConfig object. Returns an empty
877
- DictConfig if the configuration file is not found.
878
- """
879
- run_id = run.info.run_id
880
- return _load_config(run_id)
881
-
882
-
883
- @cache
884
- def _load_config(run_id: str) -> DictConfig:
885
- try:
886
- path = download_artifacts(run_id=run_id, artifact_path=".hydra/config.yaml")
887
- except OSError:
888
- return DictConfig({})
889
-
890
- return OmegaConf.load(path) # type: ignore
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.2.6
3
+ Version: 0.2.8
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -0,0 +1,12 @@
1
+ hydraflow/__init__.py,sha256=K2xXk5Za_9OkiRmbsgkuWn7EMaTcQOVCPFs5oTP_QFw,483
2
+ hydraflow/asyncio.py,sha256=yh851L315QHzRBwq6r-uwO2oZKgz1JawHp-fswfxT1E,6175
3
+ hydraflow/config.py,sha256=6TCKNQZ3sSrIEvl245T2udwFuknejyN1dMcIVmOHdrQ,2102
4
+ hydraflow/context.py,sha256=G7JMrG70sgBH2qILXl5nkGWNUoRggj518JWUq0ZiJ9E,7776
5
+ hydraflow/info.py,sha256=Vj2sT66Ric63mmaq7Yu8nDFhsGQYO3MCHrxFpapDufc,3458
6
+ hydraflow/mlflow.py,sha256=Q8RGijSURTjRkEDxzi_2Tk9KOx3QK__al5aArGQriHA,7249
7
+ hydraflow/progress.py,sha256=0GJfKnnY_SAHVWpGvLdgOBsogGs8vVofjLuphuUEy2g,4296
8
+ hydraflow/run_collection.py,sha256=Ge-PAsoQBbn7cuow0DYMf5SoBmIXUfZ9ftufN_75Pw8,29963
9
+ hydraflow-0.2.8.dist-info/METADATA,sha256=9CF5S8LdmDUx4sihDqVRvwLLk34FNBmy_Vv6BVoahoc,4181
10
+ hydraflow-0.2.8.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
11
+ hydraflow-0.2.8.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
12
+ hydraflow-0.2.8.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- hydraflow/__init__.py,sha256=l5BrZAfpJHFkQnDRuETZVjDTntMmzOI3CUwnsm2fGzk,460
2
- hydraflow/asyncio.py,sha256=yh851L315QHzRBwq6r-uwO2oZKgz1JawHp-fswfxT1E,6175
3
- hydraflow/config.py,sha256=6TCKNQZ3sSrIEvl245T2udwFuknejyN1dMcIVmOHdrQ,2102
4
- hydraflow/context.py,sha256=8Qn99yCSkCarDDthQ6hjgW80CBBIg0H7fnLvtw4ZXo8,7248
5
- hydraflow/mlflow.py,sha256=gGr0fvFEllduA-ByHMeEamM39zVY_30tjtEbkSZ4lHA,3659
6
- hydraflow/progress.py,sha256=dReFp-AfBuYpjGQnqRmkwPcoyFfe2WCgkklXuo9ZjNg,1709
7
- hydraflow/runs.py,sha256=TETX54OVJPJLi6rjpNcsXAhXH2Q9unhjXhGkOtFtHng,31559
8
- hydraflow-0.2.6.dist-info/METADATA,sha256=yOEx7M9jM5M7MNkLOZShO-DexNqXzIHjSkqbxcNMHQ0,4181
9
- hydraflow-0.2.6.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
10
- hydraflow-0.2.6.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
11
- hydraflow-0.2.6.dist-info/RECORD,,