hydraflow 0.3.2__tar.gz → 0.3.4__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {hydraflow-0.3.2 → hydraflow-0.3.4}/PKG-INFO +1 -1
- {hydraflow-0.3.2 → hydraflow-0.3.4}/pyproject.toml +1 -1
- {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/__init__.py +11 -3
- {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/config.py +13 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/context.py +1 -1
- {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/param.py +4 -1
- {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/run_collection.py +3 -13
- hydraflow-0.3.4/src/hydraflow/run_data.py +57 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/utils.py +25 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/scripts/app.py +4 -2
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_app.py +19 -39
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_config.py +21 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_log_run.py +13 -1
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_run_collection.py +17 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_run_data.py +6 -7
- hydraflow-0.3.2/src/hydraflow/run_data.py +0 -34
- {hydraflow-0.3.2 → hydraflow-0.3.4}/.devcontainer/devcontainer.json +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/.devcontainer/postCreate.sh +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/.devcontainer/starship.toml +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/.gitattributes +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/.gitignore +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/LICENSE +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/README.md +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/apps/quickstart.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/mkdocs.yml +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/asyncio.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/mlflow.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/progress.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/py.typed +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/run_info.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/__init__.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/scripts/__init__.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/scripts/progress.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/scripts/watch.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_asyncio.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_context.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_mlflow.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_param.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_progress.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_run_info.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_version.py +0 -0
- {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_watch.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: hydraflow
|
3
|
-
Version: 0.3.
|
3
|
+
Version: 0.3.4
|
4
4
|
Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
|
5
5
|
Project-URL: Documentation, https://github.com/daizutabi/hydraflow
|
6
6
|
Project-URL: Source, https://github.com/daizutabi/hydraflow
|
@@ -1,19 +1,27 @@
|
|
1
1
|
"""Integrate Hydra and MLflow to manage and track machine learning experiments."""
|
2
2
|
|
3
|
-
from .context import chdir_artifact,
|
3
|
+
from .context import chdir_artifact, chdir_hydra_output, log_run, start_run, watch
|
4
4
|
from .mlflow import list_runs, search_runs, set_experiment
|
5
5
|
from .progress import multi_tasks_progress, parallel_progress
|
6
6
|
from .run_collection import RunCollection
|
7
|
-
from .utils import
|
7
|
+
from .utils import (
|
8
|
+
get_artifact_dir,
|
9
|
+
get_hydra_output_dir,
|
10
|
+
get_overrides,
|
11
|
+
load_config,
|
12
|
+
load_overrides,
|
13
|
+
)
|
8
14
|
|
9
15
|
__all__ = [
|
10
16
|
"RunCollection",
|
11
17
|
"chdir_artifact",
|
12
|
-
"
|
18
|
+
"chdir_hydra_output",
|
13
19
|
"get_artifact_dir",
|
14
20
|
"get_hydra_output_dir",
|
21
|
+
"get_overrides",
|
15
22
|
"list_runs",
|
16
23
|
"load_config",
|
24
|
+
"load_overrides",
|
17
25
|
"log_run",
|
18
26
|
"multi_tasks_progress",
|
19
27
|
"parallel_progress",
|
@@ -44,12 +44,25 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
|
44
44
|
if config is None:
|
45
45
|
return
|
46
46
|
|
47
|
+
if isinstance(config, list) and all(isinstance(x, str) for x in config):
|
48
|
+
config = _from_dotlist(config)
|
49
|
+
|
47
50
|
if not isinstance(config, DictConfig | ListConfig):
|
48
51
|
config = OmegaConf.create(config) # type: ignore
|
49
52
|
|
50
53
|
yield from _iter_params(config, prefix)
|
51
54
|
|
52
55
|
|
56
|
+
def _from_dotlist(config: list[str]) -> dict[str, str]:
|
57
|
+
result = {}
|
58
|
+
for item in config:
|
59
|
+
if "=" in item:
|
60
|
+
key, value = item.split("=", 1)
|
61
|
+
result[key.strip()] = value.strip()
|
62
|
+
|
63
|
+
return result
|
64
|
+
|
65
|
+
|
53
66
|
def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
54
67
|
if isinstance(config, DictConfig):
|
55
68
|
for key, value in config.items():
|
@@ -239,7 +239,7 @@ class Handler(PatternMatchingEventHandler):
|
|
239
239
|
|
240
240
|
|
241
241
|
@contextmanager
|
242
|
-
def
|
242
|
+
def chdir_hydra_output() -> Iterator[Path]:
|
243
243
|
"""Change the current working directory to the hydra output directory.
|
244
244
|
|
245
245
|
This context manager changes the current working directory to the hydra output
|
@@ -34,7 +34,10 @@ def match(param: str, value: Any) -> bool:
|
|
34
34
|
if isinstance(value, tuple) and (m := _match_tuple(param, value)) is not None:
|
35
35
|
return m
|
36
36
|
|
37
|
-
if isinstance(value,
|
37
|
+
if isinstance(value, str):
|
38
|
+
return param == value
|
39
|
+
|
40
|
+
if isinstance(value, int | float):
|
38
41
|
return type(value)(param) == value
|
39
42
|
|
40
43
|
return param == str(value)
|
@@ -24,12 +24,12 @@ from itertools import chain
|
|
24
24
|
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
|
25
25
|
|
26
26
|
from mlflow.entities import RunStatus
|
27
|
-
from polars.dataframe import DataFrame
|
28
27
|
|
29
28
|
import hydraflow.param
|
30
|
-
from hydraflow.config import
|
29
|
+
from hydraflow.config import iter_params
|
31
30
|
from hydraflow.run_data import RunCollectionData
|
32
31
|
from hydraflow.run_info import RunCollectionInfo
|
32
|
+
from hydraflow.utils import load_config
|
33
33
|
|
34
34
|
if TYPE_CHECKING:
|
35
35
|
from collections.abc import Callable, Iterator
|
@@ -516,7 +516,7 @@ class RunCollection:
|
|
516
516
|
in the collection.
|
517
517
|
|
518
518
|
"""
|
519
|
-
return (func(
|
519
|
+
return (func(load_config(run), *args, **kwargs) for run in self)
|
520
520
|
|
521
521
|
def map_uri(
|
522
522
|
self,
|
@@ -599,16 +599,6 @@ class RunCollection:
|
|
599
599
|
|
600
600
|
return {key: RunCollection(runs) for key, runs in grouped_runs.items()}
|
601
601
|
|
602
|
-
@property
|
603
|
-
def config(self) -> DataFrame:
|
604
|
-
"""Get the runs' configurations as a polars DataFrame.
|
605
|
-
|
606
|
-
Returns:
|
607
|
-
A polars DataFrame containing the runs' configurations.
|
608
|
-
|
609
|
-
"""
|
610
|
-
return DataFrame(self.map_config(collect_params))
|
611
|
-
|
612
602
|
|
613
603
|
def _param_matches(run: Run, key: str, value: Any) -> bool:
|
614
604
|
params = run.data.params
|
@@ -0,0 +1,57 @@
|
|
1
|
+
"""Provide data about `RunCollection` instances."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from typing import TYPE_CHECKING
|
6
|
+
|
7
|
+
from polars.dataframe import DataFrame
|
8
|
+
|
9
|
+
from hydraflow.config import collect_params
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from collections.abc import Iterable
|
13
|
+
from typing import Any
|
14
|
+
|
15
|
+
from hydraflow.run_collection import RunCollection
|
16
|
+
|
17
|
+
|
18
|
+
class RunCollectionData:
|
19
|
+
"""Provide data about a `RunCollection` instance."""
|
20
|
+
|
21
|
+
def __init__(self, runs: RunCollection) -> None:
|
22
|
+
self._runs = runs
|
23
|
+
|
24
|
+
@property
|
25
|
+
def params(self) -> dict[str, list[str]]:
|
26
|
+
"""Get the parameters for each run in the collection."""
|
27
|
+
return _to_dict(run.data.params for run in self._runs)
|
28
|
+
|
29
|
+
@property
|
30
|
+
def metrics(self) -> dict[str, list[float]]:
|
31
|
+
"""Get the metrics for each run in the collection."""
|
32
|
+
return _to_dict(run.data.metrics for run in self._runs)
|
33
|
+
|
34
|
+
@property
|
35
|
+
def config(self) -> DataFrame:
|
36
|
+
"""Get the runs' configurations as a polars DataFrame.
|
37
|
+
|
38
|
+
Returns:
|
39
|
+
A polars DataFrame containing the runs' configurations.
|
40
|
+
|
41
|
+
"""
|
42
|
+
return DataFrame(self._runs.map_config(collect_params))
|
43
|
+
|
44
|
+
|
45
|
+
def _to_dict(it: Iterable[dict[str, Any]]) -> dict[str, list[Any]]:
|
46
|
+
"""Convert an iterable of dictionaries to a dictionary of lists."""
|
47
|
+
data = list(it)
|
48
|
+
if not data:
|
49
|
+
return {}
|
50
|
+
|
51
|
+
keys = []
|
52
|
+
for d in data:
|
53
|
+
for key in d:
|
54
|
+
if key not in keys:
|
55
|
+
keys.append(key)
|
56
|
+
|
57
|
+
return {key: [x.get(key) for x in data] for key in keys}
|
@@ -86,3 +86,28 @@ def load_config(run: Run) -> DictConfig:
|
|
86
86
|
"""
|
87
87
|
path = get_artifact_dir(run) / ".hydra/config.yaml"
|
88
88
|
return OmegaConf.load(path) # type: ignore
|
89
|
+
|
90
|
+
|
91
|
+
def get_overrides() -> list[str]:
|
92
|
+
"""Retrieve the overrides for the current run."""
|
93
|
+
return HydraConfig.get().overrides.task
|
94
|
+
|
95
|
+
|
96
|
+
def load_overrides(run: Run) -> list[str]:
|
97
|
+
"""Load the overrides for a given run.
|
98
|
+
|
99
|
+
This function loads the overrides for the provided Run instance
|
100
|
+
by downloading the overrides file from the MLflow artifacts and
|
101
|
+
loading it using OmegaConf. It returns an empty config if
|
102
|
+
`.hydra/overrides.yaml` is not found in the run's artifact directory.
|
103
|
+
|
104
|
+
Args:
|
105
|
+
run (Run): The Run instance for which to load the overrides.
|
106
|
+
|
107
|
+
Returns:
|
108
|
+
The loaded overrides as a list of strings. Returns an empty list
|
109
|
+
if the overrides file is not found.
|
110
|
+
|
111
|
+
"""
|
112
|
+
path = get_artifact_dir(run) / ".hydra/overrides.yaml"
|
113
|
+
return [str(x) for x in OmegaConf.load(path)]
|
@@ -27,11 +27,11 @@ cs.store(name="config", node=MySQLConfig)
|
|
27
27
|
|
28
28
|
@hydra.main(version_base=None, config_name="config")
|
29
29
|
def app(cfg: MySQLConfig):
|
30
|
-
with hydraflow.
|
30
|
+
with hydraflow.chdir_hydra_output() as path:
|
31
31
|
Path("chdir_hydra.txt").write_text(path.as_posix())
|
32
32
|
|
33
33
|
hydraflow.set_experiment(prefix="_", suffix="_")
|
34
|
-
with hydraflow.start_run(cfg):
|
34
|
+
with hydraflow.start_run(cfg) as run:
|
35
35
|
log.info(f"START, {cfg.host}, {cfg.port} ")
|
36
36
|
|
37
37
|
artifact_dir = hydraflow.get_artifact_dir()
|
@@ -50,6 +50,8 @@ def app(cfg: MySQLConfig):
|
|
50
50
|
if cfg.host == "x":
|
51
51
|
mlflow.log_metric("m", cfg.port + 10, 2)
|
52
52
|
|
53
|
+
assert hydraflow.get_overrides() == hydraflow.load_overrides(run)
|
54
|
+
|
53
55
|
log.info("END")
|
54
56
|
|
55
57
|
|
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
|
|
8
8
|
import mlflow
|
9
9
|
import pytest
|
10
10
|
from mlflow.entities import RunStatus
|
11
|
-
from omegaconf import
|
11
|
+
from omegaconf import OmegaConf
|
12
12
|
|
13
13
|
if TYPE_CHECKING:
|
14
14
|
from omegaconf import DictConfig
|
@@ -92,33 +92,30 @@ def test_app_info_run_id(rc: RunCollection):
|
|
92
92
|
|
93
93
|
def test_app_data_params(rc: RunCollection):
|
94
94
|
params = rc.data.params
|
95
|
-
assert params[
|
96
|
-
assert params[
|
97
|
-
assert params[
|
98
|
-
assert params[3] == {"port": "2", "host": "y", "values": "[1, 2, 3]"}
|
95
|
+
assert params["port"] == ["1", "2", "1", "2"]
|
96
|
+
assert params["host"] == ["x", "x", "y", "y"]
|
97
|
+
assert params["values"] == ["[1, 2, 3]", "[1, 2, 3]", "[1, 2, 3]", "[1, 2, 3]"]
|
99
98
|
|
100
99
|
|
101
100
|
def test_app_data_metrics(rc: RunCollection):
|
102
101
|
metrics = rc.data.metrics
|
103
|
-
assert metrics[
|
104
|
-
assert metrics[
|
105
|
-
assert metrics[2] == {"m": 2, "watch": 3}
|
106
|
-
assert metrics[3] == {"m": 3, "watch": 3}
|
102
|
+
assert metrics["m"] == [11, 12, 2, 3]
|
103
|
+
assert metrics["watch"] == [3, 3, 3, 3]
|
107
104
|
|
108
105
|
|
109
106
|
def test_app_data_config(rc: RunCollection):
|
110
107
|
config = rc.data.config
|
111
|
-
assert config[
|
112
|
-
assert config[
|
113
|
-
assert config[2].host == "y"
|
114
|
-
assert config[3].host == "y"
|
108
|
+
assert config["port"].to_list() == [1, 2, 1, 2]
|
109
|
+
assert config["host"].to_list() == ["x", "x", "y", "y"]
|
115
110
|
|
116
111
|
|
117
112
|
def test_app_data_config_list(rc: RunCollection):
|
118
113
|
config = rc.data.config
|
119
|
-
|
120
|
-
assert
|
121
|
-
|
114
|
+
values = config["values"].to_list()
|
115
|
+
assert str(config.select("values").dtypes) == "[List(Int64)]"
|
116
|
+
for x in values:
|
117
|
+
assert isinstance(x, list)
|
118
|
+
assert x == [1, 2, 3]
|
122
119
|
|
123
120
|
|
124
121
|
def test_app_info_artifact_uri(rc: RunCollection):
|
@@ -160,14 +157,12 @@ def test_app_map_config(rc: RunCollection):
|
|
160
157
|
def test_app_group_by(rc: RunCollection):
|
161
158
|
grouped = rc.group_by("host")
|
162
159
|
assert len(grouped) == 2
|
163
|
-
|
164
|
-
assert grouped["x"].data.params[
|
165
|
-
|
166
|
-
assert grouped["
|
167
|
-
|
168
|
-
assert grouped["y"].data.params[
|
169
|
-
x = {"port": "2", "host": "y", "values": "[1, 2, 3]"}
|
170
|
-
assert grouped["y"].data.params[1] == x
|
160
|
+
assert grouped["x"].data.params["port"] == ["1", "2"]
|
161
|
+
assert grouped["x"].data.params["host"] == ["x", "x"]
|
162
|
+
assert grouped["x"].data.params["values"] == ["[1, 2, 3]", "[1, 2, 3]"]
|
163
|
+
assert grouped["y"].data.params["port"] == ["1", "2"]
|
164
|
+
assert grouped["y"].data.params["host"] == ["y", "y"]
|
165
|
+
assert grouped["y"].data.params["values"] == ["[1, 2, 3]", "[1, 2, 3]"]
|
171
166
|
|
172
167
|
|
173
168
|
def test_app_group_by_list(rc: RunCollection):
|
@@ -184,18 +179,3 @@ def test_app_filter_list(rc: RunCollection):
|
|
184
179
|
assert len(filtered) == 4
|
185
180
|
filtered = rc.filter(values=[1])
|
186
181
|
assert not filtered
|
187
|
-
|
188
|
-
|
189
|
-
def test_config(rc: RunCollection):
|
190
|
-
df = rc.config
|
191
|
-
assert df.columns == ["host", "port", "values"]
|
192
|
-
assert df.shape == (4, 3)
|
193
|
-
assert df.select("host").to_series().to_list() == ["x", "x", "y", "y"]
|
194
|
-
assert df.select("port").to_series().to_list() == [1, 2, 1, 2]
|
195
|
-
assert str(df.select("values").dtypes) == "[List(Int64)]"
|
196
|
-
assert df.select("values").to_series().to_list() == [
|
197
|
-
[1, 2, 3],
|
198
|
-
[1, 2, 3],
|
199
|
-
[1, 2, 3],
|
200
|
-
[1, 2, 3],
|
201
|
-
]
|
@@ -205,3 +205,24 @@ def test_list_config_str(s):
|
|
205
205
|
assert isinstance(b, ListConfig)
|
206
206
|
t = OmegaConf.create(json.loads(s))
|
207
207
|
assert b == t
|
208
|
+
|
209
|
+
|
210
|
+
@pytest.mark.parametrize("x", [{"a": 1}, {"a": [1, 2, 3]}])
|
211
|
+
def test_collect_params_dict(x):
|
212
|
+
from hydraflow.config import collect_params
|
213
|
+
|
214
|
+
assert collect_params(x) == x
|
215
|
+
|
216
|
+
|
217
|
+
def test_collect_params_dict_dot():
|
218
|
+
from hydraflow.config import collect_params
|
219
|
+
|
220
|
+
assert collect_params({"a": {"b": 1}}) == {"a.b": 1}
|
221
|
+
assert collect_params({"a.b": 1}) == {"a.b": 1}
|
222
|
+
|
223
|
+
|
224
|
+
def test_collect_params_list_dot():
|
225
|
+
from hydraflow.config import collect_params
|
226
|
+
|
227
|
+
assert collect_params(["a=1"]) == {"a": "1"}
|
228
|
+
assert collect_params(["a.b=2", "c"]) == {"a.b": "2"}
|
@@ -50,7 +50,7 @@ def read_log(run_id: str, path: str) -> str:
|
|
50
50
|
|
51
51
|
|
52
52
|
def test_load_config(run: Run):
|
53
|
-
from hydraflow.
|
53
|
+
from hydraflow.utils import load_config
|
54
54
|
|
55
55
|
log = read_log(run.info.run_id, "log_run.log")
|
56
56
|
assert "START" in log
|
@@ -63,6 +63,18 @@ def test_load_config(run: Run):
|
|
63
63
|
assert cfg.port == int(port)
|
64
64
|
|
65
65
|
|
66
|
+
def test_load_overrides(run: Run):
|
67
|
+
from hydraflow.utils import load_overrides
|
68
|
+
|
69
|
+
log = read_log(run.info.run_id, "log_run.log")
|
70
|
+
assert "START" in log
|
71
|
+
assert "END" in log
|
72
|
+
|
73
|
+
host, port = log.splitlines()[0].split("START,")[-1].split(",")
|
74
|
+
|
75
|
+
assert load_overrides(run) == [f"host={host.strip()}", f"port={port.strip()}"]
|
76
|
+
|
77
|
+
|
66
78
|
def test_info(run: Run):
|
67
79
|
log = read_log(run.info.run_id, "artifact_dir.txt")
|
68
80
|
a, b = log.split(" ")
|
@@ -67,6 +67,8 @@ def test_filter_one(run_list: list[Run]):
|
|
67
67
|
assert len(x) == 1
|
68
68
|
x = filter_runs(run_list, p=1)
|
69
69
|
assert len(x) == 1
|
70
|
+
x = filter_runs(run_list, ["p=1"])
|
71
|
+
assert len(x) == 1
|
70
72
|
|
71
73
|
|
72
74
|
def test_filter_all(run_list: list[Run]):
|
@@ -77,6 +79,8 @@ def test_filter_all(run_list: list[Run]):
|
|
77
79
|
assert len(x) == 5
|
78
80
|
x = filter_runs(run_list, q=0)
|
79
81
|
assert len(x) == 5
|
82
|
+
x = filter_runs(run_list, ["q=0"])
|
83
|
+
assert len(x) == 5
|
80
84
|
|
81
85
|
|
82
86
|
def test_filter_list(run_list: list[Run]):
|
@@ -98,6 +102,8 @@ def test_filter_invalid_param(run_list: list[Run]):
|
|
98
102
|
|
99
103
|
x = filter_runs(run_list, {"invalid": 0})
|
100
104
|
assert len(x) == 6
|
105
|
+
x = filter_runs(run_list, ["invalid=0"])
|
106
|
+
assert len(x) == 6
|
101
107
|
|
102
108
|
|
103
109
|
def test_filter_status(run_list: list[Run]):
|
@@ -181,15 +187,20 @@ def test_filter(rc: RunCollection):
|
|
181
187
|
assert len(rc.filter()) == 6
|
182
188
|
assert len(rc.filter({})) == 6
|
183
189
|
assert len(rc.filter({"p": 1})) == 1
|
190
|
+
assert len(rc.filter(["p=1"])) == 1
|
184
191
|
assert len(rc.filter({"q": 0})) == 5
|
192
|
+
assert len(rc.filter(["q=0"])) == 5
|
185
193
|
assert len(rc.filter({"q": -1})) == 0
|
194
|
+
assert len(rc.filter(["q=-1"])) == 0
|
186
195
|
assert not rc.filter({"q": -1})
|
187
196
|
assert len(rc.filter(p=5)) == 1
|
188
197
|
assert len(rc.filter(q=0)) == 5
|
189
198
|
assert len(rc.filter(q=-1)) == 0
|
190
199
|
assert not rc.filter(q=-1)
|
191
200
|
assert len(rc.filter({"r": 2})) == 2
|
201
|
+
assert len(rc.filter(["r=2"])) == 2
|
192
202
|
assert len(rc.filter(r=0)) == 2
|
203
|
+
assert len(rc.filter(["r=0"])) == 2
|
193
204
|
|
194
205
|
|
195
206
|
def test_get(rc: RunCollection):
|
@@ -197,15 +208,21 @@ def test_get(rc: RunCollection):
|
|
197
208
|
assert isinstance(run, Run)
|
198
209
|
run = rc.get(p=2)
|
199
210
|
assert isinstance(run, Run)
|
211
|
+
run = rc.get(["p=3"])
|
212
|
+
assert isinstance(run, Run)
|
200
213
|
|
201
214
|
|
202
215
|
def test_try_get(rc: RunCollection):
|
203
216
|
run = rc.try_get({"p": 5})
|
204
217
|
assert isinstance(run, Run)
|
218
|
+
run = rc.try_get(["p=2"])
|
219
|
+
assert isinstance(run, Run)
|
205
220
|
run = rc.try_get(p=1)
|
206
221
|
assert isinstance(run, Run)
|
207
222
|
run = rc.try_get(p=-1)
|
208
223
|
assert run is None
|
224
|
+
run = rc.try_get(["p=-2"])
|
225
|
+
assert run is None
|
209
226
|
|
210
227
|
|
211
228
|
def test_get_param_names(rc: RunCollection):
|
@@ -26,18 +26,17 @@ def runs(monkeypatch, tmp_path):
|
|
26
26
|
|
27
27
|
|
28
28
|
def test_data_params(runs: RunCollection):
|
29
|
-
assert runs.data.params
|
29
|
+
assert runs.data.params["p"] == ["0", "1", "2"]
|
30
30
|
|
31
31
|
|
32
32
|
def test_data_metrics(runs: RunCollection):
|
33
33
|
m = runs.data.metrics
|
34
|
-
assert m[
|
35
|
-
assert m[
|
36
|
-
assert m[2] == {"metric1": 3, "metric2": 4}
|
34
|
+
assert m["metric1"] == [1, 2, 3]
|
35
|
+
assert m["metric2"] == [2, 3, 4]
|
37
36
|
|
38
37
|
|
39
38
|
def test_data_empty_run_collection():
|
40
39
|
rc = RunCollection([])
|
41
|
-
assert rc.data.params ==
|
42
|
-
assert rc.data.metrics ==
|
43
|
-
assert rc.data.config ==
|
40
|
+
assert rc.data.params == {}
|
41
|
+
assert rc.data.metrics == {}
|
42
|
+
assert len(rc.data.config) == 0
|
@@ -1,34 +0,0 @@
|
|
1
|
-
"""Provide data about `RunCollection` instances."""
|
2
|
-
|
3
|
-
from __future__ import annotations
|
4
|
-
|
5
|
-
from typing import TYPE_CHECKING
|
6
|
-
|
7
|
-
from hydraflow.utils import load_config
|
8
|
-
|
9
|
-
if TYPE_CHECKING:
|
10
|
-
from omegaconf import DictConfig
|
11
|
-
|
12
|
-
from hydraflow.run_collection import RunCollection
|
13
|
-
|
14
|
-
|
15
|
-
class RunCollectionData:
|
16
|
-
"""Provide data about a `RunCollection` instance."""
|
17
|
-
|
18
|
-
def __init__(self, runs: RunCollection) -> None:
|
19
|
-
self._runs = runs
|
20
|
-
|
21
|
-
@property
|
22
|
-
def params(self) -> list[dict[str, str]]:
|
23
|
-
"""Get the parameters for each run in the collection."""
|
24
|
-
return [run.data.params for run in self._runs]
|
25
|
-
|
26
|
-
@property
|
27
|
-
def metrics(self) -> list[dict[str, float]]:
|
28
|
-
"""Get the metrics for each run in the collection."""
|
29
|
-
return [run.data.metrics for run in self._runs]
|
30
|
-
|
31
|
-
@property
|
32
|
-
def config(self) -> list[DictConfig]:
|
33
|
-
"""Get the configuration for each run in the collection."""
|
34
|
-
return [load_config(run) for run in self._runs]
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|