hydraflow 0.1.1__tar.gz → 0.1.4__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {hydraflow-0.1.1 → hydraflow-0.1.4}/.devcontainer/devcontainer.json +1 -3
- {hydraflow-0.1.1 → hydraflow-0.1.4}/.devcontainer/starship.toml +1 -1
- {hydraflow-0.1.1 → hydraflow-0.1.4}/PKG-INFO +1 -1
- {hydraflow-0.1.1 → hydraflow-0.1.4}/pyproject.toml +2 -2
- {hydraflow-0.1.1 → hydraflow-0.1.4}/src/hydraflow/__init__.py +13 -5
- {hydraflow-0.1.1 → hydraflow-0.1.4}/src/hydraflow/context.py +9 -8
- {hydraflow-0.1.1 → hydraflow-0.1.4}/src/hydraflow/mlflow.py +6 -3
- hydraflow-0.1.4/src/hydraflow/runs.py +217 -0
- hydraflow-0.1.4/tests/scripts/__init__.py +0 -0
- {hydraflow-0.1.1/tests → hydraflow-0.1.4/tests/scripts}/watch.py +4 -0
- {hydraflow-0.1.1 → hydraflow-0.1.4}/tests/test_log_run.py +3 -11
- hydraflow-0.1.1/tests/test_run.py → hydraflow-0.1.4/tests/test_runs.py +96 -24
- {hydraflow-0.1.1 → hydraflow-0.1.4}/tests/test_watch.py +1 -1
- hydraflow-0.1.1/src/hydraflow/run.py +0 -172
- {hydraflow-0.1.1 → hydraflow-0.1.4}/.devcontainer/postCreate.sh +0 -0
- {hydraflow-0.1.1 → hydraflow-0.1.4}/.gitattributes +0 -0
- {hydraflow-0.1.1 → hydraflow-0.1.4}/.gitignore +0 -0
- {hydraflow-0.1.1 → hydraflow-0.1.4}/LICENSE +0 -0
- {hydraflow-0.1.1 → hydraflow-0.1.4}/README.md +0 -0
- {hydraflow-0.1.1 → hydraflow-0.1.4}/src/hydraflow/config.py +0 -0
- {hydraflow-0.1.1 → hydraflow-0.1.4}/src/hydraflow/util.py +0 -0
- {hydraflow-0.1.1/tests → hydraflow-0.1.4/tests/scripts}/log_run.py +0 -0
- {hydraflow-0.1.1 → hydraflow-0.1.4}/tests/test_config.py +0 -0
- {hydraflow-0.1.1 → hydraflow-0.1.4}/tests/test_version.py +0 -0
@@ -1,9 +1,7 @@
|
|
1
1
|
{
|
2
2
|
"name": "hydraflow",
|
3
|
-
"image": "mcr.microsoft.com/devcontainers/
|
3
|
+
"image": "mcr.microsoft.com/vscode/devcontainers/base:ubuntu-22.04",
|
4
4
|
"features": {
|
5
|
-
"ghcr.io/devcontainers-contrib/features/ruff:1": {},
|
6
|
-
"ghcr.io/devcontainers-contrib/features/hatch:2": {},
|
7
5
|
"ghcr.io/devcontainers-contrib/features/starship:1": {},
|
8
6
|
"ghcr.io/va-h/devcontainers-features/uv:1": {}
|
9
7
|
},
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "hydraflow"
|
7
|
-
version = "0.1.1"
|
7
|
+
version = "0.1.4"
|
8
8
|
description = "Hydra with MLflow"
|
9
9
|
readme = "README.md"
|
10
10
|
license = "MIT"
|
@@ -54,7 +54,7 @@ serve = "mkdocs serve --dev-addr localhost:8000 {args}"
|
|
54
54
|
deploy = "mkdocs gh-deploy --force"
|
55
55
|
|
56
56
|
[tool.ruff]
|
57
|
-
line-length =
|
57
|
+
line-length = 100
|
58
58
|
target-version = "py312"
|
59
59
|
|
60
60
|
[tool.ruff.lint]
|
@@ -1,27 +1,35 @@
|
|
1
1
|
from .context import Info, chdir_artifact, log_run, watch
|
2
2
|
from .mlflow import set_experiment
|
3
|
-
from .run import (
|
4
|
-
|
3
|
+
from .runs import (
|
4
|
+
Run,
|
5
|
+
Runs,
|
6
|
+
drop_unique_params,
|
7
|
+
filter_runs,
|
5
8
|
get_artifact_dir,
|
6
9
|
get_artifact_path,
|
7
10
|
get_artifact_uri,
|
8
|
-
get_by_config,
|
9
11
|
get_param_dict,
|
10
12
|
get_param_names,
|
13
|
+
get_run,
|
11
14
|
get_run_id,
|
15
|
+
load_config,
|
12
16
|
)
|
13
17
|
|
14
18
|
__all__ = [
|
15
19
|
"Info",
|
20
|
+
"Run",
|
21
|
+
"Runs",
|
16
22
|
"chdir_artifact",
|
17
|
-
"
|
23
|
+
"drop_unique_params",
|
24
|
+
"filter_runs",
|
18
25
|
"get_artifact_dir",
|
19
26
|
"get_artifact_path",
|
20
27
|
"get_artifact_uri",
|
21
|
-
"get_by_config",
|
22
28
|
"get_param_dict",
|
23
29
|
"get_param_names",
|
30
|
+
"get_run",
|
24
31
|
"get_run_id",
|
32
|
+
"load_config",
|
25
33
|
"log_run",
|
26
34
|
"set_experiment",
|
27
35
|
"watch",
|
@@ -13,7 +13,7 @@ from watchdog.events import FileModifiedEvent, FileSystemEventHandler
|
|
13
13
|
from watchdog.observers import Observer
|
14
14
|
|
15
15
|
from hydraflow.mlflow import log_params
|
16
|
-
from hydraflow.run import get_artifact_path
|
16
|
+
from hydraflow.runs import get_artifact_path
|
17
17
|
from hydraflow.util import uri_to_path
|
18
18
|
|
19
19
|
if TYPE_CHECKING:
|
@@ -40,14 +40,19 @@ def log_run(
|
|
40
40
|
hc = HydraConfig.get()
|
41
41
|
output_dir = Path(hc.runtime.output_dir)
|
42
42
|
uri = mlflow.get_artifact_uri()
|
43
|
-
|
43
|
+
info = Info(output_dir, uri_to_path(uri))
|
44
44
|
|
45
45
|
# Save '.hydra' config directory first.
|
46
46
|
output_subdir = output_dir / (hc.output_subdir or "")
|
47
47
|
mlflow.log_artifacts(output_subdir.as_posix(), hc.output_subdir)
|
48
48
|
|
49
|
+
def log_artifact(path: Path) -> None:
|
50
|
+
local_path = (output_dir / path).as_posix()
|
51
|
+
mlflow.log_artifact(local_path)
|
52
|
+
|
49
53
|
try:
|
50
|
-
|
54
|
+
with watch(log_artifact, output_dir):
|
55
|
+
yield info
|
51
56
|
|
52
57
|
finally:
|
53
58
|
# Save output_dir including '.hydra' config directory.
|
@@ -55,11 +60,7 @@ def log_run(
|
|
55
60
|
|
56
61
|
|
57
62
|
@contextmanager
|
58
|
-
def watch(
|
59
|
-
func: Callable[[Path], None],
|
60
|
-
dir: Path | str = "",
|
61
|
-
timeout: int = 600,
|
62
|
-
) -> Iterator[None]:
|
63
|
+
def watch(func: Callable[[Path], None], dir: Path | str = "", timeout: int = 60) -> Iterator[None]:
|
63
64
|
if not dir:
|
64
65
|
uri = mlflow.get_artifact_uri()
|
65
66
|
dir = uri_to_path(uri)
|
@@ -6,10 +6,13 @@ from hydra.core.hydra_config import HydraConfig
|
|
6
6
|
from hydraflow.config import iter_params
|
7
7
|
|
8
8
|
|
9
|
-
def set_experiment() -> None:
|
9
|
+
def set_experiment(prefix: str = "", suffix: str = "", uri: str | None = None) -> None:
|
10
|
+
if uri:
|
11
|
+
mlflow.set_tracking_uri(uri)
|
12
|
+
|
10
13
|
hc = HydraConfig.get()
|
11
|
-
|
12
|
-
mlflow.set_experiment(
|
14
|
+
name = f"{prefix}{hc.job.name}{suffix}"
|
15
|
+
mlflow.set_experiment(name)
|
13
16
|
|
14
17
|
|
15
18
|
def log_params(config: object, *, synchronous: bool | None = None) -> None:
|
@@ -0,0 +1,217 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from functools import cache
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import TYPE_CHECKING, Any
|
7
|
+
|
8
|
+
import mlflow
|
9
|
+
import numpy as np
|
10
|
+
from mlflow.entities.run import Run as Run_
|
11
|
+
from mlflow.tracking import artifact_utils
|
12
|
+
from omegaconf import DictConfig, OmegaConf
|
13
|
+
from pandas import DataFrame, Series
|
14
|
+
|
15
|
+
from hydraflow.config import iter_params
|
16
|
+
from hydraflow.util import uri_to_path
|
17
|
+
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
from typing import Any
|
20
|
+
|
21
|
+
|
22
|
+
@dataclass
|
23
|
+
class Runs:
|
24
|
+
runs: list[Run_] | DataFrame
|
25
|
+
|
26
|
+
def __repr__(self) -> str:
|
27
|
+
return f"{self.__class__.__name__}({len(self)})"
|
28
|
+
|
29
|
+
def __len__(self) -> int:
|
30
|
+
return len(self.runs)
|
31
|
+
|
32
|
+
def filter(self, config: object) -> Runs:
|
33
|
+
return Runs(filter_runs(self.runs, config))
|
34
|
+
|
35
|
+
def get(self, config: object) -> Run:
|
36
|
+
return Run(get_run(self.runs, config))
|
37
|
+
|
38
|
+
def drop_unique_params(self) -> Runs:
|
39
|
+
if isinstance(self.runs, DataFrame):
|
40
|
+
return Runs(drop_unique_params(self.runs))
|
41
|
+
|
42
|
+
raise NotImplementedError
|
43
|
+
|
44
|
+
def get_param_names(self) -> list[str]:
|
45
|
+
if isinstance(self.runs, DataFrame):
|
46
|
+
return get_param_names(self.runs)
|
47
|
+
|
48
|
+
raise NotImplementedError
|
49
|
+
|
50
|
+
def get_param_dict(self) -> dict[str, list[str]]:
|
51
|
+
if isinstance(self.runs, DataFrame):
|
52
|
+
return get_param_dict(self.runs)
|
53
|
+
|
54
|
+
raise NotImplementedError
|
55
|
+
|
56
|
+
|
57
|
+
def filter_runs(runs: list[Run_] | DataFrame, config: object) -> list[Run_] | DataFrame:
|
58
|
+
if isinstance(runs, list):
|
59
|
+
return filter_runs_list(runs, config)
|
60
|
+
|
61
|
+
return filter_runs_dataframe(runs, config)
|
62
|
+
|
63
|
+
|
64
|
+
def _is_equal(run: Run_, key: str, value: Any) -> bool:
|
65
|
+
param = run.data.params.get(key, value)
|
66
|
+
|
67
|
+
if param is None:
|
68
|
+
return False
|
69
|
+
|
70
|
+
return type(value)(param) == value
|
71
|
+
|
72
|
+
|
73
|
+
def filter_runs_list(runs: list[Run_], config: object) -> list[Run_]:
|
74
|
+
for key, value in iter_params(config):
|
75
|
+
runs = [run for run in runs if _is_equal(run, key, value)]
|
76
|
+
|
77
|
+
return runs
|
78
|
+
|
79
|
+
|
80
|
+
def filter_runs_dataframe(runs: DataFrame, config: object) -> DataFrame:
|
81
|
+
index = np.ones(len(runs), dtype=bool)
|
82
|
+
|
83
|
+
for key, value in iter_params(config):
|
84
|
+
name = f"params.{key}"
|
85
|
+
|
86
|
+
if name in runs:
|
87
|
+
series = runs[name]
|
88
|
+
is_value = -series.isna()
|
89
|
+
param = series.fillna(value).astype(type(value))
|
90
|
+
index &= is_value & (param == value)
|
91
|
+
|
92
|
+
return runs[index]
|
93
|
+
|
94
|
+
|
95
|
+
def get_run(runs: list[Run_] | DataFrame, config: object) -> Run_ | Series:
|
96
|
+
runs = filter_runs(runs, config)
|
97
|
+
|
98
|
+
if len(runs) == 1:
|
99
|
+
return runs[0] if isinstance(runs, list) else runs.iloc[0]
|
100
|
+
|
101
|
+
msg = f"number of filtered runs is not 1: got {len(runs)}"
|
102
|
+
raise ValueError(msg)
|
103
|
+
|
104
|
+
|
105
|
+
def drop_unique_params(runs: DataFrame) -> DataFrame:
|
106
|
+
def select(column: str) -> bool:
|
107
|
+
return not column.startswith("params.") or len(runs[column].unique()) > 1
|
108
|
+
|
109
|
+
columns = [select(column) for column in runs.columns]
|
110
|
+
return runs.iloc[:, columns]
|
111
|
+
|
112
|
+
|
113
|
+
def get_param_names(runs: DataFrame) -> list[str]:
|
114
|
+
def get_name(column: str) -> str:
|
115
|
+
if column.startswith("params."):
|
116
|
+
return column.split(".", maxsplit=1)[-1]
|
117
|
+
|
118
|
+
return ""
|
119
|
+
|
120
|
+
columns = [get_name(column) for column in runs.columns]
|
121
|
+
return [column for column in columns if column]
|
122
|
+
|
123
|
+
|
124
|
+
def get_param_dict(runs: DataFrame) -> dict[str, list[str]]:
|
125
|
+
params = {}
|
126
|
+
for name in get_param_names(runs):
|
127
|
+
params[name] = list(runs[f"params.{name}"].unique())
|
128
|
+
|
129
|
+
return params
|
130
|
+
|
131
|
+
|
132
|
+
@dataclass
|
133
|
+
class Run:
|
134
|
+
run: Run_ | Series | str
|
135
|
+
|
136
|
+
def __repr__(self) -> str:
|
137
|
+
return f"{self.__class__.__name__}({self.run_id!r})"
|
138
|
+
|
139
|
+
@property
|
140
|
+
def run_id(self) -> str:
|
141
|
+
return get_run_id(self.run)
|
142
|
+
|
143
|
+
def artifact_uri(self, artifact_path: str | None = None) -> str:
|
144
|
+
return get_artifact_uri(self.run, artifact_path)
|
145
|
+
|
146
|
+
@property
|
147
|
+
def artifact_dir(self) -> Path:
|
148
|
+
return get_artifact_dir(self.run)
|
149
|
+
|
150
|
+
def artifact_path(self, artifact_path: str | None = None) -> Path:
|
151
|
+
return get_artifact_path(self.run, artifact_path)
|
152
|
+
|
153
|
+
@property
|
154
|
+
def config(self) -> DictConfig:
|
155
|
+
return load_config(self.run)
|
156
|
+
|
157
|
+
def log_hydra_output_dir(self) -> None:
|
158
|
+
log_hydra_output_dir(self.run)
|
159
|
+
|
160
|
+
|
161
|
+
def get_run_id(run: Run_ | Series | str) -> str:
|
162
|
+
if isinstance(run, str):
|
163
|
+
return run
|
164
|
+
|
165
|
+
if isinstance(run, Run_):
|
166
|
+
return run.info.run_id
|
167
|
+
|
168
|
+
return run.run_id
|
169
|
+
|
170
|
+
|
171
|
+
def get_artifact_uri(run: Run_ | Series | str, artifact_path: str | None = None) -> str:
|
172
|
+
run_id = get_run_id(run)
|
173
|
+
return artifact_utils.get_artifact_uri(run_id, artifact_path)
|
174
|
+
|
175
|
+
|
176
|
+
def get_artifact_dir(run: Run_ | Series | str) -> Path:
|
177
|
+
uri = get_artifact_uri(run)
|
178
|
+
return uri_to_path(uri)
|
179
|
+
|
180
|
+
|
181
|
+
def get_artifact_path(run: Run_ | Series | str, artifact_path: str | None = None) -> Path:
|
182
|
+
artifact_dir = get_artifact_dir(run)
|
183
|
+
return artifact_dir / artifact_path if artifact_path else artifact_dir
|
184
|
+
|
185
|
+
|
186
|
+
def load_config(run: Run_ | Series | str) -> DictConfig:
|
187
|
+
run_id = get_run_id(run)
|
188
|
+
return _load_config(run_id)
|
189
|
+
|
190
|
+
|
191
|
+
@cache
|
192
|
+
def _load_config(run_id: str) -> DictConfig:
|
193
|
+
try:
|
194
|
+
path = mlflow.artifacts.download_artifacts(
|
195
|
+
run_id=run_id,
|
196
|
+
artifact_path=".hydra/config.yaml",
|
197
|
+
)
|
198
|
+
except OSError:
|
199
|
+
return DictConfig({})
|
200
|
+
|
201
|
+
return OmegaConf.load(path) # type: ignore
|
202
|
+
|
203
|
+
|
204
|
+
def get_hydra_output_dir(run: Run_ | Series | str) -> Path:
|
205
|
+
path = get_artifact_dir(run) / ".hydra/hydra.yaml"
|
206
|
+
|
207
|
+
if path.exists():
|
208
|
+
hc = OmegaConf.load(path)
|
209
|
+
return Path(hc.hydra.runtime.output_dir)
|
210
|
+
|
211
|
+
raise FileNotFoundError
|
212
|
+
|
213
|
+
|
214
|
+
def log_hydra_output_dir(run: Run_ | Series | str) -> None:
|
215
|
+
output_dir = get_hydra_output_dir(run)
|
216
|
+
run_id = run if isinstance(run, str) else run.info.run_id
|
217
|
+
mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
|
File without changes
|
@@ -1,3 +1,4 @@
|
|
1
|
+
import os
|
1
2
|
import sys
|
2
3
|
import time
|
3
4
|
from pathlib import Path
|
@@ -14,6 +15,9 @@ def run():
|
|
14
15
|
Path(f"{k}.txt").write_text(f"{k} {time.time()}")
|
15
16
|
time.sleep(1)
|
16
17
|
|
18
|
+
os.makedirs("a/b")
|
19
|
+
Path("a/b/c.txt").write_text(f"4 {time.time()}")
|
20
|
+
|
17
21
|
|
18
22
|
if __name__ == "__main__":
|
19
23
|
run()
|
@@ -12,12 +12,10 @@ from mlflow.entities.run import Run
|
|
12
12
|
|
13
13
|
@pytest.fixture
|
14
14
|
def runs(monkeypatch, tmp_path):
|
15
|
-
file = Path("tests/log_run.py").absolute()
|
15
|
+
file = Path("tests/scripts/log_run.py").absolute()
|
16
16
|
monkeypatch.chdir(tmp_path)
|
17
17
|
|
18
|
-
subprocess.check_call(
|
19
|
-
[sys.executable, file.as_posix(), "-m", "host=x,y", "port=1,2"]
|
20
|
-
)
|
18
|
+
subprocess.check_call([sys.executable, file.as_posix(), "-m", "host=x,y", "port=1,2"])
|
21
19
|
|
22
20
|
mlflow.set_experiment("log_run")
|
23
21
|
runs = mlflow.search_runs(output_format="list")
|
@@ -48,7 +46,7 @@ def read_log(run_id: str) -> str:
|
|
48
46
|
|
49
47
|
|
50
48
|
def test_load_config(run_id: str):
|
51
|
-
from hydraflow.run import load_config
|
49
|
+
from hydraflow.runs import load_config
|
52
50
|
|
53
51
|
log = read_log(run_id)
|
54
52
|
host, port = log.splitlines()[0].split("START,")[-1].split(",")
|
@@ -56,9 +54,3 @@ def test_load_config(run_id: str):
|
|
56
54
|
cfg = load_config(run_id)
|
57
55
|
assert cfg.host == host.strip()
|
58
56
|
assert cfg.port == int(port)
|
59
|
-
|
60
|
-
|
61
|
-
def test_load_config_err(run_id: str):
|
62
|
-
from hydraflow.run import load_config
|
63
|
-
|
64
|
-
assert not load_config(run_id, "a")
|
@@ -45,47 +45,47 @@ def runs_df(_runs: tuple[list[Run], DataFrame], request: pytest.FixtureRequest):
|
|
45
45
|
return _runs[1]
|
46
46
|
|
47
47
|
|
48
|
-
def
|
49
|
-
from hydraflow.
|
48
|
+
def test_filter_one(runs: list[Run] | DataFrame):
|
49
|
+
from hydraflow.runs import filter_runs
|
50
50
|
|
51
51
|
assert len(runs) == 6
|
52
|
-
x =
|
52
|
+
x = filter_runs(runs, {"p": 1})
|
53
53
|
assert len(x) == 1
|
54
54
|
|
55
55
|
|
56
|
-
def
|
57
|
-
from hydraflow.
|
56
|
+
def test_filter_all(runs: list[Run] | DataFrame):
|
57
|
+
from hydraflow.runs import filter_runs
|
58
58
|
|
59
59
|
assert len(runs) == 6
|
60
|
-
x =
|
60
|
+
x = filter_runs(runs, {"q": 0})
|
61
61
|
assert len(x) == 6
|
62
62
|
|
63
63
|
|
64
|
-
def
|
65
|
-
from hydraflow.
|
64
|
+
def test_get_list(runs_list: list[Run]):
|
65
|
+
from hydraflow.runs import get_run
|
66
66
|
|
67
|
-
run =
|
67
|
+
run = get_run(runs_list, {"p": 4})
|
68
68
|
assert isinstance(run, Run)
|
69
69
|
assert run.data.params["p"] == "4"
|
70
70
|
|
71
71
|
|
72
|
-
def
|
73
|
-
from hydraflow.
|
72
|
+
def test_get_df(runs_df: DataFrame):
|
73
|
+
from hydraflow.runs import get_run
|
74
74
|
|
75
|
-
run =
|
75
|
+
run = get_run(runs_df, {"p": 2})
|
76
76
|
assert isinstance(run, Series)
|
77
77
|
assert run["params.p"] == "2"
|
78
78
|
|
79
79
|
|
80
|
-
def
|
81
|
-
from hydraflow.
|
80
|
+
def test_get_error(runs: list[Run] | DataFrame):
|
81
|
+
from hydraflow.runs import get_run
|
82
82
|
|
83
83
|
with pytest.raises(ValueError):
|
84
|
-
|
84
|
+
get_run(runs, {"q": 0})
|
85
85
|
|
86
86
|
|
87
87
|
def test_drop_unique_params(runs_df):
|
88
|
-
from hydraflow.run import drop_unique_params
|
88
|
+
from hydraflow.runs import drop_unique_params
|
89
89
|
|
90
90
|
assert "params.p" in runs_df
|
91
91
|
assert "params.q" in runs_df
|
@@ -95,7 +95,7 @@ def test_drop_unique_params(runs_df):
|
|
95
95
|
|
96
96
|
|
97
97
|
def test_get_param_names(runs_df: DataFrame):
|
98
|
-
from hydraflow.run import get_param_names
|
98
|
+
from hydraflow.runs import get_param_names
|
99
99
|
|
100
100
|
params = get_param_names(runs_df)
|
101
101
|
assert len(params) == 2
|
@@ -104,7 +104,7 @@ def test_get_param_names(runs_df: DataFrame):
|
|
104
104
|
|
105
105
|
|
106
106
|
def test_get_param_dict(runs_df: DataFrame):
|
107
|
-
from hydraflow.run import get_param_dict
|
107
|
+
from hydraflow.runs import get_param_dict
|
108
108
|
|
109
109
|
params = get_param_dict(runs_df)
|
110
110
|
assert len(params["p"]) == 6
|
@@ -113,7 +113,7 @@ def test_get_param_dict(runs_df: DataFrame):
|
|
113
113
|
|
114
114
|
@pytest.mark.parametrize("i", range(6))
|
115
115
|
def test_get_run_id(i: int, runs_list: list[Run], runs_df: DataFrame):
|
116
|
-
from hydraflow.run import get_run_id
|
116
|
+
from hydraflow.runs import get_run_id
|
117
117
|
|
118
118
|
assert get_run_id(runs_list[i]) == get_run_id(runs_df.iloc[i])
|
119
119
|
assert get_run_id(runs_list[i]) == get_run_id(runs_df.iloc[i])
|
@@ -125,7 +125,7 @@ def test_get_run_id(i: int, runs_list: list[Run], runs_df: DataFrame):
|
|
125
125
|
@pytest.mark.parametrize("i", range(6))
|
126
126
|
@pytest.mark.parametrize("path", [None, "a"])
|
127
127
|
def test_get_artifact_uri(i: int, path, runs_list: list[Run], runs_df: DataFrame):
|
128
|
-
from hydraflow.run import get_artifact_uri, get_run_id
|
128
|
+
from hydraflow.runs import get_artifact_uri, get_run_id
|
129
129
|
|
130
130
|
x = get_run_id(runs_list[i])
|
131
131
|
y = get_artifact_uri(runs_list[i], path)
|
@@ -144,7 +144,7 @@ def test_chdir_artifact_list(i: int, runs_list: list[Run]):
|
|
144
144
|
|
145
145
|
|
146
146
|
def test_hydra_output_dir_error(runs_list: list[Run]):
|
147
|
-
from hydraflow.run import get_hydra_output_dir
|
147
|
+
from hydraflow.runs import get_hydra_output_dir
|
148
148
|
|
149
149
|
with pytest.raises(FileNotFoundError):
|
150
150
|
get_hydra_output_dir(runs_list[0])
|
@@ -163,7 +163,7 @@ def df():
|
|
163
163
|
|
164
164
|
|
165
165
|
def test_unique_params(df):
|
166
|
-
from hydraflow.run import drop_unique_params
|
166
|
+
from hydraflow.runs import drop_unique_params
|
167
167
|
|
168
168
|
df = drop_unique_params(df)
|
169
169
|
assert len(df.columns) == 3
|
@@ -173,16 +173,88 @@ def test_unique_params(df):
|
|
173
173
|
|
174
174
|
|
175
175
|
def test_param_names(df):
|
176
|
-
from hydraflow.run import get_param_names
|
176
|
+
from hydraflow.runs import get_param_names
|
177
177
|
|
178
178
|
names = get_param_names(df)
|
179
179
|
assert names == ["x", "y", "z"]
|
180
180
|
|
181
181
|
|
182
182
|
def test_param_dict(df):
|
183
|
-
from hydraflow.run import get_param_dict
|
183
|
+
from hydraflow.runs import get_param_dict
|
184
184
|
|
185
185
|
x = get_param_dict(df)
|
186
186
|
assert x["x"] == [1, 2]
|
187
187
|
assert x["y"] == [1, 2]
|
188
188
|
assert x["z"] == [1]
|
189
|
+
|
190
|
+
|
191
|
+
def test_runs_repr(runs):
|
192
|
+
from hydraflow.runs import Runs
|
193
|
+
|
194
|
+
assert repr(Runs(runs)) == "Runs(6)"
|
195
|
+
|
196
|
+
|
197
|
+
def test_runs_filter(runs):
|
198
|
+
from hydraflow.runs import Runs
|
199
|
+
|
200
|
+
runs = Runs(runs)
|
201
|
+
|
202
|
+
assert len(runs.filter({})) == 6
|
203
|
+
assert len(runs.filter({"p": 1})) == 1
|
204
|
+
assert len(runs.filter({"q": 0})) == 6
|
205
|
+
assert len(runs.filter({"q": -1})) == 0
|
206
|
+
|
207
|
+
|
208
|
+
def test_runs_get(runs):
|
209
|
+
from hydraflow.runs import Run, Runs
|
210
|
+
|
211
|
+
runs = Runs(runs)
|
212
|
+
run = runs.get({"p": 4})
|
213
|
+
assert isinstance(run, Run)
|
214
|
+
|
215
|
+
|
216
|
+
def test_runs_drop_unique_params(runs_df):
|
217
|
+
from hydraflow.runs import Runs
|
218
|
+
|
219
|
+
runs = Runs(runs_df)
|
220
|
+
assert runs.runs.shape == (6, 12) # type: ignore
|
221
|
+
runs = runs.drop_unique_params()
|
222
|
+
assert runs.runs.shape == (6, 11) # type: ignore
|
223
|
+
|
224
|
+
|
225
|
+
def test_runs_get_params_names(runs_df):
|
226
|
+
from hydraflow.runs import Runs
|
227
|
+
|
228
|
+
runs = Runs(runs_df)
|
229
|
+
names = runs.get_param_names()
|
230
|
+
assert len(names) == 2
|
231
|
+
assert "p" in names
|
232
|
+
assert "q" in names
|
233
|
+
|
234
|
+
|
235
|
+
def test_runs_get_params_dict(runs_df):
|
236
|
+
from hydraflow.runs import Runs
|
237
|
+
|
238
|
+
runs = Runs(runs_df)
|
239
|
+
params = runs.get_param_dict()
|
240
|
+
assert params["p"] == ["0", "1", "2", "3", "4", "5"]
|
241
|
+
assert params["q"] == ["0"]
|
242
|
+
|
243
|
+
|
244
|
+
@pytest.fixture
|
245
|
+
def run(runs):
|
246
|
+
from hydraflow.runs import Runs
|
247
|
+
|
248
|
+
return Runs(runs).get({"p": 5})
|
249
|
+
|
250
|
+
|
251
|
+
def test_run_id(run):
|
252
|
+
assert run.run_id in repr(run)
|
253
|
+
|
254
|
+
|
255
|
+
def test_run_artifact_uri(run):
|
256
|
+
assert run.artifact_uri().startswith("file:")
|
257
|
+
|
258
|
+
|
259
|
+
def test_run_artifact_dir(run):
|
260
|
+
assert run.artifact_dir.exists()
|
@@ -1,172 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from pathlib import Path
|
4
|
-
from typing import TYPE_CHECKING, Any, overload
|
5
|
-
|
6
|
-
import mlflow
|
7
|
-
import numpy as np
|
8
|
-
from mlflow.entities.run import Run
|
9
|
-
from mlflow.tracking import artifact_utils
|
10
|
-
from omegaconf import DictConfig, OmegaConf
|
11
|
-
|
12
|
-
from hydraflow.config import iter_params
|
13
|
-
from hydraflow.util import uri_to_path
|
14
|
-
|
15
|
-
if TYPE_CHECKING:
|
16
|
-
from typing import Any
|
17
|
-
|
18
|
-
from pandas import DataFrame, Series
|
19
|
-
|
20
|
-
|
21
|
-
@overload
|
22
|
-
def filter_by_config(runs: list[Run], config: object) -> list[Run]: ...
|
23
|
-
|
24
|
-
|
25
|
-
@overload
|
26
|
-
def filter_by_config(runs: DataFrame, config: object) -> DataFrame: ...
|
27
|
-
|
28
|
-
|
29
|
-
def filter_by_config(runs: list[Run] | DataFrame, config: object):
|
30
|
-
if isinstance(runs, list):
|
31
|
-
return filter_by_config_list(runs, config)
|
32
|
-
|
33
|
-
return filter_by_config_dataframe(runs, config)
|
34
|
-
|
35
|
-
|
36
|
-
def _is_equal(run: Run, key: str, value: Any) -> bool:
|
37
|
-
param = run.data.params.get(key, value)
|
38
|
-
if param is None:
|
39
|
-
return False
|
40
|
-
|
41
|
-
return type(value)(param) == value
|
42
|
-
|
43
|
-
|
44
|
-
def filter_by_config_list(runs: list[Run], config: object) -> list[Run]:
|
45
|
-
for key, value in iter_params(config):
|
46
|
-
runs = [run for run in runs if _is_equal(run, key, value)]
|
47
|
-
|
48
|
-
return runs
|
49
|
-
|
50
|
-
|
51
|
-
def filter_by_config_dataframe(runs: DataFrame, config: object) -> DataFrame:
|
52
|
-
index = np.ones(len(runs), dtype=bool)
|
53
|
-
|
54
|
-
for key, value in iter_params(config):
|
55
|
-
name = f"params.{key}"
|
56
|
-
if name in runs:
|
57
|
-
series = runs[name]
|
58
|
-
is_value = -series.isna()
|
59
|
-
param = series.fillna(value).astype(type(value))
|
60
|
-
index &= is_value & (param == value)
|
61
|
-
|
62
|
-
return runs[index]
|
63
|
-
|
64
|
-
|
65
|
-
@overload
|
66
|
-
def get_by_config(runs: list[Run], config: object) -> Run: ...
|
67
|
-
|
68
|
-
|
69
|
-
@overload
|
70
|
-
def get_by_config(runs: DataFrame, config: object) -> Series: ...
|
71
|
-
|
72
|
-
|
73
|
-
def get_by_config(runs: list[Run] | DataFrame, config: object):
|
74
|
-
runs = filter_by_config(runs, config)
|
75
|
-
|
76
|
-
if len(runs) == 1:
|
77
|
-
return runs[0] if isinstance(runs, list) else runs.iloc[0]
|
78
|
-
|
79
|
-
msg = f"filtered runs has not length of 1.: {len(runs)}"
|
80
|
-
raise ValueError(msg)
|
81
|
-
|
82
|
-
|
83
|
-
def drop_unique_params(runs: DataFrame) -> DataFrame:
|
84
|
-
def select(column: str) -> bool:
|
85
|
-
return not column.startswith("params.") or len(runs[column].unique()) > 1
|
86
|
-
|
87
|
-
columns = [select(column) for column in runs.columns]
|
88
|
-
return runs.iloc[:, columns]
|
89
|
-
|
90
|
-
|
91
|
-
def get_param_names(runs: DataFrame) -> list[str]:
|
92
|
-
def get_name(column: str) -> str:
|
93
|
-
if column.startswith("params."):
|
94
|
-
return column.split(".", maxsplit=1)[-1]
|
95
|
-
|
96
|
-
return ""
|
97
|
-
|
98
|
-
columns = [get_name(column) for column in runs.columns]
|
99
|
-
return [column for column in columns if column]
|
100
|
-
|
101
|
-
|
102
|
-
def get_param_dict(runs: DataFrame) -> dict[str, list[str]]:
|
103
|
-
params = {}
|
104
|
-
for name in get_param_names(runs):
|
105
|
-
params[name] = list(runs[f"params.{name}"].unique())
|
106
|
-
|
107
|
-
return params
|
108
|
-
|
109
|
-
|
110
|
-
def get_run_id(run: Run | Series | str) -> str:
|
111
|
-
if isinstance(run, Run):
|
112
|
-
return run.info.run_id
|
113
|
-
if isinstance(run, str):
|
114
|
-
return run
|
115
|
-
return run.run_id
|
116
|
-
|
117
|
-
|
118
|
-
def get_artifact_uri(run: Run | Series | str, artifact_path: str | None = None) -> str:
|
119
|
-
if isinstance(run, Run):
|
120
|
-
uri = run.info.artifact_uri
|
121
|
-
elif isinstance(run, str):
|
122
|
-
uri = artifact_utils.get_artifact_uri(run_id=run)
|
123
|
-
else:
|
124
|
-
uri = run.artifact_uri
|
125
|
-
|
126
|
-
if artifact_path:
|
127
|
-
uri = f"{uri}/{artifact_path}"
|
128
|
-
|
129
|
-
return uri # type: ignore
|
130
|
-
|
131
|
-
|
132
|
-
def get_artifact_dir(run: Run | Series | str) -> Path:
|
133
|
-
uri = get_artifact_uri(run)
|
134
|
-
return uri_to_path(uri)
|
135
|
-
|
136
|
-
|
137
|
-
def get_artifact_path(
|
138
|
-
run: Run | Series | str,
|
139
|
-
artifact_path: str | None = None,
|
140
|
-
) -> Path:
|
141
|
-
artifact_dir = get_artifact_dir(run)
|
142
|
-
return artifact_dir / artifact_path if artifact_path else artifact_dir
|
143
|
-
|
144
|
-
|
145
|
-
def load_config(run: Run | Series | str, output_subdir: str = ".hydra") -> DictConfig:
|
146
|
-
run_id = get_run_id(run)
|
147
|
-
|
148
|
-
try:
|
149
|
-
path = mlflow.artifacts.download_artifacts(
|
150
|
-
run_id=run_id,
|
151
|
-
artifact_path=f"{output_subdir}/config.yaml",
|
152
|
-
)
|
153
|
-
except OSError:
|
154
|
-
return DictConfig({})
|
155
|
-
|
156
|
-
return OmegaConf.load(path) # type: ignore
|
157
|
-
|
158
|
-
|
159
|
-
def get_hydra_output_dir(run: Run | Series | str) -> Path:
|
160
|
-
path = get_artifact_dir(run) / ".hydra/hydra.yaml"
|
161
|
-
|
162
|
-
if path.exists():
|
163
|
-
hc = OmegaConf.load(path)
|
164
|
-
return Path(hc.hydra.runtime.output_dir)
|
165
|
-
|
166
|
-
raise FileNotFoundError
|
167
|
-
|
168
|
-
|
169
|
-
def log_hydra_output_dir(run: Run | Series | str) -> None:
|
170
|
-
output_dir = get_hydra_output_dir(run)
|
171
|
-
run_id = run if isinstance(run, str) else run.info.run_id
|
172
|
-
mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|