hydraflow 0.2.5.tar.gz → 0.2.7.tar.gz
- {hydraflow-0.2.5 → hydraflow-0.2.7}/PKG-INFO +3 -1
- hydraflow-0.2.7/mlruns/0/meta.yaml +6 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/pyproject.toml +3 -1
- {hydraflow-0.2.5 → hydraflow-0.2.7}/src/hydraflow/__init__.py +2 -2
- hydraflow-0.2.7/src/hydraflow/info.py +63 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/src/hydraflow/mlflow.py +30 -35
- hydraflow-0.2.7/src/hydraflow/progress.py +131 -0
- hydraflow-0.2.5/src/hydraflow/runs.py → hydraflow-0.2.7/src/hydraflow/run_collection.py +133 -82
- {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_asyncio.py +1 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_context.py +22 -14
- hydraflow-0.2.7/tests/test_info.py +51 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_log_run.py +1 -1
- hydraflow-0.2.7/tests/test_progress.py +12 -0
- hydraflow-0.2.5/tests/test_runs.py → hydraflow-0.2.7/tests/test_run_collection.py +92 -40
- {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_watch.py +4 -2
- {hydraflow-0.2.5 → hydraflow-0.2.7}/.devcontainer/devcontainer.json +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/.devcontainer/postCreate.sh +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/.devcontainer/starship.toml +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/.gitattributes +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/.gitignore +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/LICENSE +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/README.md +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/src/hydraflow/asyncio.py +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/src/hydraflow/config.py +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/src/hydraflow/context.py +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/scripts/__init__.py +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/scripts/log_run.py +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/scripts/watch.py +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_config.py +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_mlflow.py +0 -0
- {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_version.py +0 -0

{hydraflow-0.2.5 → hydraflow-0.2.7}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: hydraflow
-Version: 0.2.5
+Version: 0.2.7
 Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
 Project-URL: Documentation, https://github.com/daizutabi/hydraflow
 Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -17,7 +17,9 @@ Classifier: Topic :: Documentation
 Classifier: Topic :: Software Development :: Documentation
 Requires-Python: >=3.10
 Requires-Dist: hydra-core>1.3
+Requires-Dist: joblib
 Requires-Dist: mlflow>2.15
+Requires-Dist: rich
 Requires-Dist: setuptools
 Requires-Dist: watchdog
 Requires-Dist: watchfiles

{hydraflow-0.2.5 → hydraflow-0.2.7}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "hydraflow"
-version = "0.2.5"
+version = "0.2.7"
 description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
 readme = "README.md"
 license = "MIT"
@@ -21,7 +21,9 @@ classifiers = [
 requires-python = ">=3.10"
 dependencies = [
     "hydra-core>1.3",
+    "joblib",
     "mlflow>2.15",
+    "rich",
     "setuptools",
     "watchdog",
     "watchfiles",

{hydraflow-0.2.5 → hydraflow-0.2.7}/src/hydraflow/__init__.py
@@ -1,9 +1,9 @@
 from .context import chdir_artifact, log_run, start_run, watch
+from .info import load_config
 from .mlflow import get_artifact_dir, get_hydra_output_dir, set_experiment
-from .runs import (
+from .run_collection import (
     RunCollection,
     list_runs,
-    load_config,
     search_runs,
 )
 
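
The re-exports above change the package's import surface: load_config now comes from hydraflow.info, and the run-collection helpers move from hydraflow.runs to hydraflow.run_collection. A minimal sketch of the 0.2.7 top-level imports (the experiment name is a placeholder):

    import hydraflow

    runs = hydraflow.list_runs(["my_experiment"])  # hypothetical experiment name
    cfg = hydraflow.load_config(runs[0])           # re-exported from hydraflow.info
    print(type(runs))                              # hydraflow.run_collection.RunCollection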

hydraflow-0.2.7/src/hydraflow/info.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from omegaconf import DictConfig, OmegaConf
+
+from hydraflow.mlflow import get_artifact_dir
+
+if TYPE_CHECKING:
+    from pathlib import Path
+
+    from mlflow.entities import Run
+
+    from hydraflow.run_collection import RunCollection
+
+
+class RunCollectionInfo:
+    def __init__(self, runs: RunCollection):
+        self._runs = runs
+
+    @property
+    def run_id(self) -> list[str]:
+        return [run.info.run_id for run in self._runs]
+
+    @property
+    def params(self) -> list[dict[str, str]]:
+        return [run.data.params for run in self._runs]
+
+    @property
+    def metrics(self) -> list[dict[str, float]]:
+        return [run.data.metrics for run in self._runs]
+
+    @property
+    def artifact_uri(self) -> list[str | None]:
+        return [run.info.artifact_uri for run in self._runs]
+
+    @property
+    def artifact_dir(self) -> list[Path]:
+        return [get_artifact_dir(run) for run in self._runs]
+
+    @property
+    def config(self) -> list[DictConfig]:
+        return [load_config(run) for run in self._runs]
+
+
+def load_config(run: Run) -> DictConfig:
+    """
+    Load the configuration for a given run.
+
+    This function loads the configuration for the provided Run instance
+    by downloading the configuration file from the MLflow artifacts and
+    loading it using OmegaConf. It returns an empty config if
+    `.hydra/config.yaml` is not found in the run's artifact directory.
+
+    Args:
+        run (Run): The Run instance for which to load the configuration.
+
+    Returns:
+        The loaded configuration as a DictConfig object. Returns an empty
+        DictConfig if the configuration file is not found.
+    """
+    path = get_artifact_dir(run) / ".hydra/config.yaml"
+    return OmegaConf.load(path)  # type: ignore
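
The new info module gives every RunCollection a lightweight accessor object exposing per-run IDs, params, metrics, artifact locations, and loaded configs. A minimal usage sketch, assuming a tracking directory already populated by Hydra runs (the experiment name is a placeholder):

    from hydraflow import list_runs

    runs = list_runs(["my_experiment"])  # hypothetical experiment name
    print(runs.info.run_id)              # one run ID per run
    print(runs.info.params)              # list of {param: value} dicts (values are strings)
    cfg = runs.info.config[0]            # DictConfig loaded from .hydra/config.yaml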

{hydraflow-0.2.5 → hydraflow-0.2.7}/src/hydraflow/mlflow.py
@@ -17,6 +17,7 @@ from hydraflow.config import iter_params
 
 if TYPE_CHECKING:
     from mlflow.entities.experiment import Experiment
+    from mlflow.entities.run import Run
 
 
 def set_experiment(
@@ -65,60 +66,54 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
         mlflow.log_param(key, value, synchronous=synchronous)
 
 
-def get_artifact_dir(
-    artifact_path: str | None = None,
-    *,
-    run_id: str | None = None,
-) -> Path:
+def get_artifact_dir(run: Run | None = None) -> Path:
     """
-
+    Retrieve the artifact directory for the given run.
 
-    This function
-    using MLflow, downloads the artifacts to a local directory, and returns
-    the path to that directory.
+    This function uses MLflow to get the artifact directory for the given run.
 
     Args:
-
-            directory. Defaults to None.
-        run_id (str | None): The run ID for which to get the artifact directory.
+        run (Run | None): The run object. Defaults to None.
 
     Returns:
         The local path to the directory where the artifacts are downloaded.
     """
-    if
-        uri = mlflow.get_artifact_uri(
+    if run is None:
+        uri = mlflow.get_artifact_uri()
     else:
-        uri = artifact_utils.get_artifact_uri(run_id
+        uri = artifact_utils.get_artifact_uri(run.info.run_id)
 
-
+    return Path(mlflow.artifacts.download_artifacts(uri))
 
-    return Path(dir)
 
+def get_hydra_output_dir(*, run: Run | None = None) -> Path:
+    """
+    Retrieve the Hydra output directory for the given run.
+
+    This function returns the Hydra output directory. If no run is provided,
+    it retrieves the output directory from the current Hydra configuration.
+    If a run is provided, it retrieves the artifact path for the run, loads
+    the Hydra configuration from the downloaded artifacts, and returns the
+    output directory specified in that configuration.
+
+    Args:
+        run (Run | None): The run object. Defaults to None.
+
+    Returns:
+        Path: The path to the Hydra output directory.
 
-
-
+    Raises:
+        FileNotFoundError: If the Hydra configuration file is not found
+            in the artifacts.
+    """
+    if run is None:
         hc = HydraConfig.get()
         return Path(hc.runtime.output_dir)
 
-    path = get_artifact_dir(
+    path = get_artifact_dir(run) / ".hydra/hydra.yaml"
 
     if path.exists():
         hc = OmegaConf.load(path)
         return Path(hc.hydra.runtime.output_dir)
 
     raise FileNotFoundError
-
-
-# def log_hydra_output_dir(run: Run_ | Series | str) -> None:
-#     """
-#     Log the Hydra output directory.
-
-#     Args:
-#         run: The run object.
-
-#     Returns:
-#         None
-#     """
-#     output_dir = get_hydra_output_dir(run)
-#     run_id = run if isinstance(run, str) else run.info.run_id
-#     mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
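
In the hunk above, get_artifact_dir drops the artifact_path/run_id keywords in favor of a single optional Run argument, and get_hydra_output_dir becomes a documented public helper. A sketch of the 0.2.7 call pattern (the run lookup and experiment name are illustrative):

    from hydraflow import list_runs
    from hydraflow.mlflow import get_artifact_dir, get_hydra_output_dir

    run = list_runs(["my_experiment"])[0]        # hypothetical experiment name
    artifact_dir = get_artifact_dir(run)         # downloads the run's artifacts, returns a local Path
    output_dir = get_hydra_output_dir(run=run)   # reads .hydra/hydra.yaml from the downloaded artifacts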

hydraflow-0.2.7/src/hydraflow/progress.py
@@ -0,0 +1,131 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import joblib
+from rich.progress import Progress
+
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+    from rich.progress import ProgressColumn
+
+
+def multi_task_progress(
+    iterables: Iterable[Iterable[int | tuple[int, int]]],
+    *columns: ProgressColumn | str,
+    n_jobs: int = -1,
+    description: str = "#{:0>3}",
+    main_description: str = "main",
+    transient: bool | None = None,
+    **kwargs,
+) -> None:
+    """
+    Render auto-updating progress bars for multiple tasks concurrently.
+
+    Args:
+        iterables (Iterable[Iterable[int | tuple[int, int]]]): A collection of
+            iterables, each representing a task. Each iterable can yield
+            integers (completed) or tuples of integers (completed, total).
+        *columns (ProgressColumn | str): Additional columns to display in the
+            progress bars.
+        n_jobs (int, optional): Number of jobs to run in parallel. Defaults to
+            -1, which means using all processors.
+        description (str, optional): Format string for describing tasks. Defaults to
+            "#{:0>3}".
+        main_description (str, optional): Description for the main task.
+            Defaults to "main".
+        transient (bool | None, optional): Whether to remove the progress bar
+            after completion. Defaults to None.
+        **kwargs: Additional keyword arguments passed to the Progress instance.
+
+    Returns:
+        None
+    """
+    if not columns:
+        columns = Progress.get_default_columns()
+
+    iterables = list(iterables)
+
+    with Progress(*columns, transient=transient or False, **kwargs) as progress:
+        n = len(iterables)
+
+        task_main = progress.add_task(main_description, total=None) if n > 1 else None
+        tasks = [
+            progress.add_task(description.format(i), start=False, total=None) for i in range(n)
+        ]
+
+        total = {}
+        completed = {}
+
+        def func(i: int) -> None:
+            completed[i] = 0
+            total[i] = None
+            progress.start_task(tasks[i])
+
+            for index in iterables[i]:
+                if isinstance(index, tuple):
+                    completed[i], total[i] = index[0] + 1, index[1]
+                else:
+                    completed[i] = index + 1
+
+                progress.update(tasks[i], total=total[i], completed=completed[i])
+                if task_main is not None:
+                    if all(t is not None for t in total.values()):
+                        t = sum(total.values())
+                    else:
+                        t = None
+                    c = sum(completed.values())
+                    progress.update(task_main, total=t, completed=c)
+
+            if transient or n > 1:
+                progress.remove_task(tasks[i])
+
+        if n > 1:
+            it = (joblib.delayed(func)(i) for i in range(n))
+            joblib.Parallel(n_jobs, prefer="threads")(it)
+
+        else:
+            func(0)
+
+
+if __name__ == "__main__":
+    import random
+    import time
+
+    from rich.progress import MofNCompleteColumn, Progress, SpinnerColumn, TimeElapsedColumn
+
+    from hydraflow.progress import multi_task_progress
+
+    def task(total):
+        for i in range(total or 90):
+            if total is None:
+                yield i
+            else:
+                yield i, total
+            time.sleep(random.random() / 30)
+
+    def multi_task_progress_test(unknown_total: bool):
+        tasks = [task(random.randint(80, 100)) for _ in range(4)]
+        if unknown_total:
+            tasks = [task(None), *tasks, task(None)]
+
+        columns = [
+            SpinnerColumn(),
+            *Progress.get_default_columns(),
+            MofNCompleteColumn(),
+            TimeElapsedColumn(),
+        ]
+
+        kwargs = {}
+        if unknown_total:
+            kwargs["main_description"] = "unknown"
+
+        multi_task_progress(tasks, *columns, n_jobs=4, **kwargs)
+
+    multi_task_progress_test(False)
+    multi_task_progress_test(True)
+    multi_task_progress([task(100)])
+    multi_task_progress([task(None)], description="unknown")
+    multi_task_progress([task(100), task(None)], main_description="transient", transient=True)
+    multi_task_progress([task(100)], description="transient", transient=True)
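
multi_task_progress drives one rich progress bar per iterable plus an aggregate "main" bar when there is more than one task; each iterable yields either a completed count or a (completed, total) pair. A small sketch in the same spirit as the __main__ demo above (the work function is a placeholder):

    import time
    from hydraflow.progress import multi_task_progress

    def work(total: int):
        # Yield (index, total) so each bar knows its length up front.
        for i in range(total):
            time.sleep(0.01)
            yield i, total

    # Three task bars plus the aggregate bar, run on a joblib thread pool.
    multi_task_progress([work(50), work(80), work(120)], n_jobs=3)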

hydraflow-0.2.5/src/hydraflow/runs.py → hydraflow-0.2.7/src/hydraflow/run_collection.py
@@ -6,24 +6,25 @@ runs, retrieve run information, log artifacts, and load configurations.
 
 from __future__ import annotations
 
-from dataclasses import dataclass
-from functools import cache
+from dataclasses import dataclass, field
 from itertools import chain
-from typing import TYPE_CHECKING, Any, TypeVar
+from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar
 
 import mlflow
-from mlflow.artifacts import download_artifacts
 from mlflow.entities import ViewType
 from mlflow.entities.run import Run
 from mlflow.tracking.fluent import SEARCH_MAX_RESULTS_PANDAS
-from omegaconf import DictConfig, OmegaConf
 
 from hydraflow.config import iter_params
+from hydraflow.info import RunCollectionInfo
 
 if TYPE_CHECKING:
     from collections.abc import Callable, Iterator
+    from pathlib import Path
     from typing import Any
 
+    from omegaconf import DictConfig
+
 
 def search_runs(
     experiment_ids: list[str] | None = None,
@@ -51,13 +52,6 @@ def search_runs(
             error if ``experiment_names`` is also not ``None`` or ``[]``.
            ``None`` will default to the active experiment if ``experiment_names``
            is ``None`` or ``[]``.
-        experiment_ids (list[str] | None): List of experiment IDs. Search can
-            work with experiment IDs or experiment names, but not both in the
-            same call. Values other than ``None`` or ``[]`` will result in
-            error if ``experiment_names`` is also not ``None`` or ``[]``.
-            ``experiment_names`` is also not ``None`` or ``[]``. ``None`` will
-            default to the active experiment if ``experiment_names`` is ``None``
-            or ``[]``.
         filter_string (str): Filter query string, defaults to searching all
             runs.
        run_view_type (int): one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``,
@@ -128,6 +122,7 @@ def list_runs(experiment_names: list[str] | None = None) -> RunCollection:
 
 
 T = TypeVar("T")
+P = ParamSpec("P")
 
 
 @dataclass
@@ -142,6 +137,12 @@ class RunCollection:
     _runs: list[Run]
     """A list of MLflow Run objects."""
 
+    _info: RunCollectionInfo = field(init=False)
+    """A list of MLflow Run objects."""
+
+    def __post_init__(self):
+        self._info = RunCollectionInfo(self)
+
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({len(self)})"
 
@@ -157,6 +158,10 @@
     def __contains__(self, run: Run) -> bool:
         return run in self._runs
 
+    @property
+    def info(self) -> RunCollectionInfo:
+        return self._info
+
     def sort(
         self,
         key: Callable[[Run], Any] | None = None,
@@ -418,52 +423,81 @@
         """
         return get_param_dict(self._runs)
 
-    def map(
+    def map(
+        self,
+        func: Callable[Concatenate[Run, P], T],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> Iterator[T]:
         """
         Apply a function to each run in the collection and return an iterator of
         results.
 
+        This method iterates over each run in the collection and applies the
+        provided function to it, along with any additional arguments and
+        keyword arguments.
+
         Args:
-            func (Callable[[Run], T]): A function that takes a run and
-                result.
+            func (Callable[[Run, P], T]): A function that takes a run and
+                additional arguments and returns a result.
+            *args: Additional arguments to pass to the function.
+            **kwargs: Additional keyword arguments to pass to the function.
 
         Yields:
-            Results obtained by applying the function to each run in the
-            collection.
+            Results obtained by applying the function to each run in the collection.
         """
-        return (func(run) for run in self
+        return (func(run, *args, **kwargs) for run in self)
 
-    def map_run_id(
+    def map_run_id(
+        self,
+        func: Callable[Concatenate[str, P], T],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> Iterator[T]:
         """
         Apply a function to each run id in the collection and return an iterator
         of results.
 
         Args:
-            func (Callable[[str], T]): A function that takes a run id and returns a
+            func (Callable[[str, P], T]): A function that takes a run id and returns a
                result.
+            *args: Additional arguments to pass to the function.
+            **kwargs: Additional keyword arguments to pass to the function.
 
         Yields:
            Results obtained by applying the function to each run id in the
            collection.
        """
-        return (func(
+        return (func(run_id, *args, **kwargs) for run_id in self.info.run_id)
 
-    def map_config(
+    def map_config(
+        self,
+        func: Callable[Concatenate[DictConfig, P], T],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> Iterator[T]:
        """
        Apply a function to each run configuration in the collection and return
        an iterator of results.
 
        Args:
-            func (Callable[[DictConfig], T]): A function that takes a run
+            func (Callable[[DictConfig, P], T]): A function that takes a run
                configuration and returns a result.
+            *args: Additional arguments to pass to the function.
+            **kwargs: Additional keyword arguments to pass to the function.
 
        Yields:
            Results obtained by applying the function to each run configuration
            in the collection.
        """
-        return (func(
+        return (func(config, *args, **kwargs) for config in self.info.config)
 
-    def map_uri(
+    def map_uri(
+        self,
+        func: Callable[Concatenate[str | None, P], T],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> Iterator[T]:
        """
        Apply a function to each artifact URI in the collection and return an
        iterator of results.
@@ -473,16 +507,23 @@
             have an artifact URI, None is passed to the function.
 
         Args:
-            func (Callable[[str | None], T]): A function that takes an
-
+            func (Callable[[str | None, P], T]): A function that takes an
+                artifact URI (string or None) and returns a result.
+            *args: Additional arguments to pass to the function.
+            **kwargs: Additional keyword arguments to pass to the function.
 
         Yields:
             Results obtained by applying the function to each artifact URI in the
             collection.
         """
-        return (func(
+        return (func(uri, *args, **kwargs) for uri in self.info.artifact_uri)
 
-    def map_dir(
+    def map_dir(
+        self,
+        func: Callable[Concatenate[Path, P], T],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> Iterator[T]:
         """
         Apply a function to each artifact directory in the collection and return
         an iterator of results.
@@ -492,42 +533,61 @@
             path.
 
         Args:
-            func (Callable[[
+            func (Callable[[Path, P], T]): A function that takes an artifact directory
                 path (string) and returns a result.
+            *args: Additional arguments to pass to the function.
+            **kwargs: Additional keyword arguments to pass to the function.
 
         Yields:
             Results obtained by applying the function to each artifact directory
             in the collection.
         """
-        return (func(
+        return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir)
 
-    def group_by(
-        self, names: list[str] | None = None, *args
-    ) -> dict[tuple[str, ...], RunCollection]:
+    def group_by(self, *names: str | list[str]) -> dict[tuple[str | None, ...], RunCollection]:
         """
-        Group
-
+        Group runs by specified parameter names.
+
+        This method groups the runs in the collection based on the values of the
+        specified parameters. Each unique combination of parameter values will
+        form a key in the returned dictionary.
 
         Args:
-            names (list[str]
-
+            *names (str | list[str]): The names of the parameters to group by.
+                This can be a single parameter name or multiple names provided
+                as separate arguments or as a list.
 
         Returns:
-
-            are the
+            dict[tuple[str | None, ...], RunCollection]: A dictionary where the keys
+            are tuples of parameter values and the values are RunCollection objects
+            containing the runs that match those parameter values.
         """
-
-        names.extend(args)
-
-        grouped_runs = {}
+        grouped_runs: dict[tuple[str | None, ...], list[Run]] = {}
         for run in self._runs:
-            key = get_params(run, names)
-
-                grouped_runs[key] = []
-            grouped_runs[key].append(run)
+            key = get_params(run, *names)
+            grouped_runs.setdefault(key, []).append(run)
 
         return {key: RunCollection(runs) for key, runs in grouped_runs.items()}
 
+    def group_by_values(self, *names: str | list[str]) -> list[RunCollection]:
+        """
+        Group runs by specified parameter names.
+
+        This method groups the runs in the collection based on the values of the
+        specified parameters. Each unique combination of parameter values will
+        form a separate RunCollection in the returned list.
+
+        Args:
+            *names (str | list[str]): The names of the parameters to group by.
+                This can be a single parameter name or multiple names provided
+                as separate arguments or as a list.
+
+        Returns:
+            list[RunCollection]: A list of RunCollection objects, where each
+            object contains runs that match the specified parameter values.
+        """
+        return list(self.group_by(*names).values())
+
 
 def _param_matches(run: Run, key: str, value: Any) -> bool:
     """
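
group_by now takes the parameter names directly (as separate arguments or as a list) and keys the result by tuples of parameter values, while group_by_values returns only the grouped collections. A sketch, using parameter names that appear in the tests later in this diff:

    from hydraflow import list_runs

    runs = list_runs(["my_experiment"])        # hypothetical experiment name
    by_p = runs.group_by("p")                  # {("0",): RunCollection(...), ("1",): ...}
    by_pq = runs.group_by("p", "q")            # keys are (p value, q value) tuples
    groups = runs.group_by_values(["p", "q"])  # same grouping, collections only, as a list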

@@ -792,11 +852,32 @@ def try_get_run(runs: list[Run], config: object | None = None, **kwargs) -> Run
         raise ValueError(msg)
 
 
-def get_params(run: Run, names: list[str]
-
-
+def get_params(run: Run, *names: str | list[str]) -> tuple[str | None, ...]:
+    """
+    Retrieve the values of specified parameters from the given run.
+
+    This function extracts the values of the parameters identified by the
+    provided names from the specified run. It can accept both individual
+    parameter names and lists of parameter names.
+
+    Args:
+        run (Run): The run object from which to extract parameter values.
+        *names (str | list[str]): The names of the parameters to retrieve.
+            This can be a single parameter name or multiple names provided
+            as separate arguments or as a list.
+
+    Returns:
+        tuple[str | None, ...]: A tuple containing the values of the specified
+        parameters in the order they were provided.
+    """
+    names_ = []
+    for name in names:
+        if isinstance(name, list):
+            names_.extend(name)
+        else:
+            names_.append(name)
 
-    return tuple(run.data.params
+    return tuple(run.data.params.get(name) for name in names_)
 
 
 def get_param_names(runs: list[Run]) -> list[str]:
@@ -846,33 +927,3 @@ def get_param_dict(runs: list[Run]) -> dict[str, list[str]]:
         params[name] = sorted(set(it))
 
     return params
-
-
-def load_config(run: Run) -> DictConfig:
-    """
-    Load the configuration for a given run.
-
-    This function loads the configuration for the provided Run instance
-    by downloading the configuration file from the MLflow artifacts and
-    loading it using OmegaConf. It returns an empty config if
-    `.hydra/config.yaml` is not found in the run's artifact directory.
-
-    Args:
-        run (Run): The Run instance for which to load the configuration.
-
-    Returns:
-        The loaded configuration as a DictConfig object. Returns an empty
-        DictConfig if the configuration file is not found.
-    """
-    run_id = run.info.run_id
-    return _load_config(run_id)
-
-
-@cache
-def _load_config(run_id: str) -> DictConfig:
-    try:
-        path = download_artifacts(run_id=run_id, artifact_path=".hydra/config.yaml")
-    except OSError:
-        return DictConfig({})
-
-    return OmegaConf.load(path)  # type: ignore
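
With the Concatenate/ParamSpec signatures above, the map helpers forward extra positional and keyword arguments to the mapped function, and load_config has moved from hydraflow.runs to hydraflow.info. A sketch of the new call style (file name and experiment name are placeholders):

    from hydraflow import list_runs, load_config

    runs = list_runs(["my_experiment"])
    paths = list(runs.map_dir(lambda d, name: d / name, "metrics.csv"))  # extra arg forwarded
    ids = list(runs.map_run_id(lambda run_id, prefix: prefix + run_id, prefix="run-"))
    cfg = load_config(runs[0])  # was hydraflow.runs.load_config in 0.2.5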

{hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_context.py
@@ -1,15 +1,17 @@
+import time
+from pathlib import Path
 from unittest.mock import MagicMock, patch
 
 import mlflow
 import pytest
 
 from hydraflow.context import log_run, start_run, watch
-from hydraflow.runs import RunCollection
+from hydraflow.run_collection import RunCollection
 
 
 @pytest.fixture
 def runs(monkeypatch, tmp_path):
-    from hydraflow.runs import list_runs
+    from hydraflow.run_collection import list_runs
 
     monkeypatch.chdir(tmp_path)
 
@@ -17,7 +19,7 @@ def runs(monkeypatch, tmp_path):
         patch("hydraflow.context.HydraConfig.get") as mock_hydra_config,
         patch("hydraflow.context.mlflow.log_artifacts") as mock_log_artifacts,
     ):
-        mock_hydra_config.return_value.runtime.output_dir =
+        mock_hydra_config.return_value.runtime.output_dir = tmp_path.as_posix()
         mock_log_artifacts.return_value = None
 
         mlflow.set_experiment("test_run")
@@ -49,7 +51,7 @@ def test_runs_params_dict(runs: RunCollection, i: int):
     assert runs[i].data.params["d.i"] == str(i)
 
 
-def test_log_run_error_handling():
+def test_log_run_error_handling(tmp_path: Path):
     config = MagicMock()
     config.some_param = "value"
 
@@ -59,7 +61,7 @@ def test_log_run_error_handling():
         patch("hydraflow.context.mlflow.log_artifacts") as mock_log_artifacts,
     ):
         mock_log_params.side_effect = Exception("Test exception")
-        mock_hydra_config.return_value.runtime.output_dir =
+        mock_hydra_config.return_value.runtime.output_dir = tmp_path.as_posix()
         mock_log_artifacts.return_value = None
 
         with pytest.raises(Exception, match="Test exception"):
@@ -67,14 +69,20 @@ def test_log_run_error_handling():
             pass
 
 
-def
-
-
+def test_watch_context_manager(tmp_path: Path):
+    test_dir = tmp_path / "test_watch"
+    test_dir.mkdir(parents=True, exist_ok=True)
+    test_file = test_dir / "test_file.txt"
 
-
-        mock_observer_instance = mock_observer.return_value
-        mock_observer_instance.start.side_effect = Exception("Test exception")
+    called = []
 
-
-
-
+    def mock_func(path: Path):
+        assert path == test_file
+        called.append(path)
+
+    with watch(mock_func, test_dir):
+        test_file.write_text("new content")
+        time.sleep(1)
+
+    assert len(called) == 1
+    assert called[0] == test_file

hydraflow-0.2.7/tests/test_info.py
@@ -0,0 +1,51 @@
+from pathlib import Path
+
+import mlflow
+import pytest
+
+from hydraflow.run_collection import RunCollection
+
+
+@pytest.fixture
+def runs(monkeypatch, tmp_path):
+    from hydraflow.run_collection import search_runs
+
+    monkeypatch.chdir(tmp_path)
+
+    mlflow.set_experiment("test_info")
+    for x in range(3):
+        with mlflow.start_run(run_name=f"{x}"):
+            mlflow.log_param("p", x)
+            mlflow.log_metric("metric1", x + 1)
+            mlflow.log_metric("metric2", x + 2)
+
+    x = search_runs()
+    assert isinstance(x, RunCollection)
+    return x
+
+
+def test_info_run_id(runs: RunCollection):
+    assert len(runs.info.run_id) == 3
+
+
+def test_info_params(runs: RunCollection):
+    assert runs.info.params == [{"p": "0"}, {"p": "1"}, {"p": "2"}]
+
+
+def test_info_metrics(runs: RunCollection):
+    m = runs.info.metrics
+    assert m[0] == {"metric1": 1, "metric2": 2}
+    assert m[1] == {"metric1": 2, "metric2": 3}
+    assert m[2] == {"metric1": 3, "metric2": 4}
+
+
+def test_info_artifact_uri(runs: RunCollection):
+    uri = runs.info.artifact_uri
+    assert all(u.startswith("file://") for u in uri)  # type: ignore
+    assert all(u.endswith("/artifacts") for u in uri)  # type: ignore
+
+
+def test_info_artifact_dir(runs: RunCollection):
+    dir = runs.info.artifact_dir
+    assert all(isinstance(d, Path) for d in dir)
+    assert all(d.stem == "artifacts" for d in dir)  # type: ignore

hydraflow-0.2.7/tests/test_progress.py
@@ -0,0 +1,12 @@
+import sys
+from subprocess import run
+
+import pytest
+
+
+@pytest.mark.skipif(
+    sys.platform == "win32", reason="'cp932' codec can't encode character '\\u2807'"
+)
+def test_progress_bar():
+    cp = run([sys.executable, "-m", "hydraflow.progress"])
+    assert cp.returncode == 0

hydraflow-0.2.5/tests/test_runs.py → hydraflow-0.2.7/tests/test_run_collection.py
@@ -5,14 +5,13 @@ from pathlib import Path
 import mlflow
 import pytest
 from mlflow.entities import Run
-from omegaconf import DictConfig
 
-from hydraflow.runs import RunCollection
+from hydraflow.run_collection import RunCollection
 
 
 @pytest.fixture
 def runs(monkeypatch, tmp_path):
-    from hydraflow.runs import search_runs
+    from hydraflow.run_collection import search_runs
 
     monkeypatch.chdir(tmp_path)
 
@@ -39,13 +38,13 @@ def test_search_runs_sorted(run_list: list[Run]):
 
 
 def test_filter_none(run_list: list[Run]):
-    from hydraflow.runs import filter_runs
+    from hydraflow.run_collection import filter_runs
 
     assert run_list == filter_runs(run_list)
 
 
 def test_filter_one(run_list: list[Run]):
-    from hydraflow.runs import filter_runs
+    from hydraflow.run_collection import filter_runs
 
     assert len(run_list) == 6
     x = filter_runs(run_list, {"p": 1})
@@ -55,7 +54,7 @@ def test_filter_one(run_list: list[Run]):
 
 
 def test_filter_all(run_list: list[Run]):
-    from hydraflow.runs import filter_runs
+    from hydraflow.run_collection import filter_runs
 
     assert len(run_list) == 6
     x = filter_runs(run_list, {"q": 0})
@@ -65,28 +64,28 @@ def test_filter_all(run_list: list[Run]):
 
 
 def test_filter_list(run_list: list[Run]):
-    from hydraflow.runs import filter_runs
+    from hydraflow.run_collection import filter_runs
 
     x = filter_runs(run_list, p=[0, 4, 5])
     assert len(x) == 3
 
 
 def test_filter_tuple(run_list: list[Run]):
-    from hydraflow.runs import filter_runs
+    from hydraflow.run_collection import filter_runs
 
     x = filter_runs(run_list, p=(1, 3))
     assert len(x) == 2
 
 
 def test_filter_invalid_param(run_list: list[Run]):
-    from hydraflow.runs import filter_runs
+    from hydraflow.run_collection import filter_runs
 
     x = filter_runs(run_list, {"invalid": 0})
     assert len(x) == 6
 
 
 def test_find_run(run_list: list[Run]):
-    from hydraflow.runs import find_run, try_find_run
+    from hydraflow.run_collection import find_run, try_find_run
 
     x = find_run(run_list, {"r": 1})
     assert isinstance(x, Run)
@@ -100,20 +99,20 @@ def test_find_run(run_list: list[Run]):
 
 
 def test_find_run_none(run_list: list[Run]):
-    from hydraflow.runs import find_run
+    from hydraflow.run_collection import find_run
 
     with pytest.raises(ValueError):
         find_run(run_list, {"r": 10})
 
 
 def test_try_find_run_none_empty(run_list: list[Run]):
-    from hydraflow.runs import try_find_run
+    from hydraflow.run_collection import try_find_run
 
     assert try_find_run([]) is None
 
 
 def test_find_last_run(run_list: list[Run]):
-    from hydraflow.runs import find_last_run, try_find_last_run
+    from hydraflow.run_collection import find_last_run, try_find_last_run
 
     x = find_last_run(run_list, {"r": 1})
     assert isinstance(x, Run)
@@ -127,20 +126,20 @@ def test_find_last_run(run_list: list[Run]):
 
 
 def test_find_last_run_none(run_list: list[Run]):
-    from hydraflow.runs import find_last_run
+    from hydraflow.run_collection import find_last_run
 
     with pytest.raises(ValueError):
         find_last_run(run_list, {"r": 10})
 
 
 def test_try_find_last_run_none(run_list: list[Run]):
-    from hydraflow.runs import try_find_last_run
+    from hydraflow.run_collection import try_find_last_run
 
     assert try_find_last_run([]) is None
 
 
 def test_get_run(run_list: list[Run]):
-    from hydraflow.runs import get_run
+    from hydraflow.run_collection import get_run
 
     run = get_run(run_list, {"p": 4})
     assert isinstance(run, Run)
@@ -148,7 +147,7 @@ def test_get_run(run_list: list[Run]):
 
 
 def test_get_run_error(run_list: list[Run]):
-    from hydraflow.runs import get_run
+    from hydraflow.run_collection import get_run
 
     with pytest.raises(ValueError):
         get_run(run_list, {"q": 0})
@@ -158,20 +157,30 @@ def test_get_run_error(run_list: list[Run]):
 
 
 def test_try_get_run_none(run_list: list[Run]):
-    from hydraflow.runs import try_get_run
+    from hydraflow.run_collection import try_get_run
 
     assert try_get_run(run_list, {"q": -1}) is None
 
 
 def test_try_get_run_error(run_list: list[Run]):
-    from hydraflow.runs import try_get_run
+    from hydraflow.run_collection import try_get_run
 
     with pytest.raises(ValueError):
         try_get_run(run_list, {"q": 0})
 
 
+def test_get_params(run_list: list[Run]):
+    from hydraflow.run_collection import get_params
+
+    assert get_params(run_list[1], "p") == ("1",)
+    assert get_params(run_list[2], "p", "q") == ("2", "0")
+    assert get_params(run_list[3], ["p", "q"]) == ("3", "0")
+    assert get_params(run_list[4], "p", ["q", "r"]) == ("4", "0", "1")
+    assert get_params(run_list[5], ["a", "q"], "r") == (None, "None", "2")
+
+
 def test_get_param_names(run_list: list[Run]):
-    from hydraflow.runs import get_param_names
+    from hydraflow.run_collection import get_param_names
 
     params = get_param_names(run_list)
     assert len(params) == 3
@@ -181,7 +190,7 @@ def test_get_param_names(run_list: list[Run]):
 
 
 def test_get_param_dict(run_list: list[Run]):
-    from hydraflow.runs import get_param_dict
+    from hydraflow.run_collection import get_param_dict
 
     params = get_param_dict(run_list)
     assert len(params["p"]) == 6
@@ -250,7 +259,7 @@ def test_runs_filter(runs: RunCollection):
 
 
 def test_runs_get(runs: RunCollection):
-    from hydraflow.runs import Run
+    from hydraflow.run_collection import Run
 
     run = runs.get({"p": 4})
     assert isinstance(run, Run)
@@ -283,7 +292,7 @@ def test_runs_get_params_dict(runs: RunCollection):
 
 
 def test_runs_find(runs: RunCollection):
-    from hydraflow.runs import Run
+    from hydraflow.run_collection import Run
 
     run = runs.find({"r": 0})
     assert isinstance(run, Run)
@@ -304,7 +313,7 @@ def test_runs_try_find_none(runs: RunCollection):
 
 
 def test_runs_find_last(runs: RunCollection):
-    from hydraflow.runs import Run
+    from hydraflow.run_collection import Run
 
     run = runs.find_last({"r": 0})
     assert isinstance(run, Run)
@@ -333,7 +342,7 @@ def runs2(monkeypatch, tmp_path):
 
 
 def test_list_runs(runs, runs2):
-    from hydraflow.runs import list_runs
+    from hydraflow.run_collection import list_runs
 
     mlflow.set_experiment("test_run")
     all_runs = list_runs()
@@ -345,7 +354,7 @@ def test_list_runs(runs, runs2):
 
 
 def test_list_runs_empty_list(runs, runs2):
-    from hydraflow.runs import list_runs
+    from hydraflow.run_collection import list_runs
 
     all_runs = list_runs([])
     assert len(all_runs) == 9
@@ -353,14 +362,14 @@ def test_list_runs_empty_list(runs, runs2):
 
 @pytest.mark.parametrize(["name", "n"], [("test_run", 6), ("test_run2", 3)])
 def test_list_runs_list(runs, runs2, name, n):
-    from hydraflow.runs import list_runs
+    from hydraflow.run_collection import list_runs
 
     filtered_runs = list_runs(experiment_names=[name])
     assert len(filtered_runs) == n
 
 
 def test_list_runs_none(runs, runs2):
-    from hydraflow.runs import list_runs
+    from hydraflow.run_collection import list_runs
 
     no_runs = list_runs(experiment_names=["non_existent_experiment"])
     assert len(no_runs) == 0
@@ -372,16 +381,20 @@ def test_run_collection_map(runs: RunCollection):
     assert all(isinstance(run_id, str) for run_id in results)
 
 
+def test_run_collection_map_args(runs: RunCollection):
+    results = list(runs.map(lambda run, x: run.info.run_id + x, "test"))
+    assert all(x.endswith("test") for x in results)
+
+
 def test_run_collection_map_run_id(runs: RunCollection):
     results = list(runs.map_run_id(lambda run_id: run_id))
     assert len(results) == len(runs._runs)
     assert all(isinstance(run_id, str) for run_id in results)
 
 
-def
-    results = list(runs.
-    assert
-    assert all(isinstance(config, DictConfig) for config in results)
+def test_run_collection_map_run_id_kwargs(runs: RunCollection):
+    results = list(runs.map_run_id(lambda run_id, x: x + run_id, x="test"))
+    assert all(x.startswith("test") for x in results)
 
 
 def test_run_collection_map_uri(runs: RunCollection):
@@ -391,9 +404,10 @@ def test_run_collection_map_uri(runs: RunCollection):
 
 
 def test_run_collection_map_dir(runs: RunCollection):
-    results = list(runs.map_dir(lambda dir_path: dir_path))
+    results = list(runs.map_dir(lambda dir_path, x: dir_path / x, "a.csv"))
     assert len(results) == len(runs._runs)
-    assert all(isinstance(dir_path,
+    assert all(isinstance(dir_path, Path) for dir_path in results)
+    assert all(dir_path.stem == "a" for dir_path in results)
 
 
 def test_run_collection_sort(runs: RunCollection):
@@ -427,15 +441,53 @@ def test_run_collection_group_by(runs: RunCollection):
     assert grouped[("0",)][0] == runs[0]
     assert grouped[("1",)][0] == runs[1]
 
-    grouped = runs.group_by(
+    grouped = runs.group_by("q")
     assert len(grouped) == 2
 
-    grouped = runs.group_by(
+    grouped = runs.group_by("r")
     assert len(grouped) == 3
 
 
-
-
+def test_filter_runs_empty_list():
+    from hydraflow.run_collection import filter_runs
+
+    x = filter_runs([], p=[0, 1, 2])
+    assert x == []
+
+
+def test_filter_runs_no_match(run_list: list[Run]):
+    from hydraflow.run_collection import filter_runs
+
+    x = filter_runs(run_list, p=[10, 11, 12])
+    assert x == []
+
+
+def test_get_run_no_match(run_list: list[Run]):
+    from hydraflow.run_collection import get_run
+
+    with pytest.raises(ValueError):
+        get_run(run_list, {"p": 10})
+
+
+def test_get_run_multiple_params(run_list: list[Run]):
+    from hydraflow.run_collection import get_run
+
+    run = get_run(run_list, {"p": 4, "q": 0})
+    assert isinstance(run, Run)
+    assert run.data.params["p"] == "4"
+    assert run.data.params["q"] == "0"
+
 
-
-
+def test_try_get_run_no_match(run_list: list[Run]):
+    from hydraflow.run_collection import try_get_run
+
+    assert try_get_run(run_list, {"p": 10}) is None
+
+
+def test_try_get_run_multiple_params(run_list: list[Run]):
+    from hydraflow.run_collection import try_get_run
+
+    run = try_get_run(run_list, {"p": 4, "q": 0})
+    assert isinstance(run, Run)
+    assert run.data.params["p"] == "4"
+    assert run.data.params["q"] == "0"

{hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_watch.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import subprocess
+import time
 from pathlib import Path
 
 import pytest
@@ -21,6 +22,7 @@ def test_watch(dir, monkeypatch, tmp_path):
 
     with watch(func, dir if isinstance(dir, str) else dir()):
         subprocess.check_call(["python", file])
+        time.sleep(1)
 
-    assert results[0][0] == "watch.txt"
-    assert results[0][1] == "watch"
+    assert results[0][0] == "watch.txt"  # type: ignore
+    assert results[0][1] == "watch"  # type: ignore