hydraflow 0.1.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,16 @@
1
+ {
2
+ "name": "hydraflow",
3
+ "image": "mcr.microsoft.com/devcontainers/python:3.10-bookworm",
4
+ "features": {
5
+ "ghcr.io/devcontainers-contrib/features/ruff:1": {},
6
+ "ghcr.io/devcontainers-contrib/features/hatch:2": {},
7
+ "ghcr.io/devcontainers-contrib/features/starship:1": {},
8
+ "ghcr.io/va-h/devcontainers-features/uv:1": {}
9
+ },
10
+ "customizations": {
11
+ "vscode": {
12
+ "extensions": ["charliermarsh.ruff"]
13
+ }
14
+ },
15
+ "postCreateCommand": ".devcontainer/postCreate.sh"
16
+ }
@@ -0,0 +1,5 @@
1
+ #!/bin/sh
2
+
3
+ echo 'eval "$(starship init bash)"' >> ~/.bashrc
4
+ mkdir -p ~/.config
5
+ cp .devcontainer/starship.toml ~/.config
@@ -0,0 +1,29 @@
1
+ "$schema" = 'https://starship.rs/config-schema.json'
2
+
3
+ add_newline = false
4
+
5
+ [username]
6
+ disabled = true
7
+
8
+ [hostname]
9
+ disabled = true
10
+
11
+ [package]
12
+ format = '[$symbol$version]($style) '
13
+ symbol = "󰏗 "
14
+ style = 'bold blue'
15
+
16
+ [git_branch]
17
+ format = '[$symbol$branch(:$remote_branch)]($style)'
18
+ symbol = ' '
19
+ style = 'bold green'
20
+
21
+ [git_status]
22
+ style = 'red'
23
+ modified = '*'
24
+ format = '([$modified]($style)) '
25
+
26
+ [python]
27
+ format = '[${symbol}${pyenv_prefix}(${version} )(\($virtualenv\) )]($style)'
28
+ symbol = ' '
29
+ style = 'yellow'
@@ -0,0 +1,4 @@
1
+ .coverage
2
+ .venv/
3
+ __pycache__/
4
+ lcov.info
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Daizu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,29 @@
1
+ Metadata-Version: 2.3
2
+ Name: hydraflow
3
+ Version: 0.1.0
4
+ Summary: Hydra with MLflow
5
+ Project-URL: Source, https://github.com/daizutabi/hydraflow
6
+ Project-URL: Issues, https://github.com/daizutabi/hydraflow/issues
7
+ Author-email: daizutabi <daizutabi@gmail.com>
8
+ License-Expression: MIT
9
+ License-File: LICENSE
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Programming Language :: Python
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Topic :: Documentation
16
+ Classifier: Topic :: Software Development :: Documentation
17
+ Requires-Python: >=3.10
18
+ Requires-Dist: hydra-core
19
+ Requires-Dist: mlflow
20
+ Requires-Dist: watchdog
21
+ Provides-Extra: dev
22
+ Requires-Dist: pytest-clarity; extra == 'dev'
23
+ Requires-Dist: pytest-cov; extra == 'dev'
24
+ Requires-Dist: pytest-randomly; extra == 'dev'
25
+ Requires-Dist: pytest-xdist; extra == 'dev'
26
+ Requires-Dist: setuptools; extra == 'dev'
27
+ Description-Content-Type: text/markdown
28
+
29
+ # hydraflow
@@ -0,0 +1 @@
1
+ # hydraflow
@@ -0,0 +1,67 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "hydraflow"
7
+ version = "0.1.0"
8
+ description = "Hydra with MLflow"
9
+ readme = "README.md"
10
+ license = "MIT"
11
+ authors = [{ name = "daizutabi", email = "daizutabi@gmail.com" }]
12
+ classifiers = [
13
+ "Development Status :: 4 - Beta",
14
+ "Programming Language :: Python",
15
+ "Programming Language :: Python :: 3.10",
16
+ "Programming Language :: Python :: 3.11",
17
+ "Programming Language :: Python :: 3.12",
18
+ "Topic :: Documentation",
19
+ "Topic :: Software Development :: Documentation",
20
+ ]
21
+ requires-python = ">=3.10"
22
+ dependencies = ["hydra-core", "mlflow", "watchdog"]
23
+
24
+ [project.optional-dependencies]
25
+ dev = [
26
+ "pytest-clarity",
27
+ "pytest-cov",
28
+ "pytest-randomly",
29
+ "pytest-xdist",
30
+ "setuptools",
31
+ ]
32
+
33
+ [project.urls]
34
+ # Documentation = "https://daizutabi.github.io/hydraflow/"
35
+ Source = "https://github.com/daizutabi/hydraflow"
36
+ Issues = "https://github.com/daizutabi/hydraflow/issues"
37
+
38
+ [tool.hatch.build.targets.sdist]
39
+ exclude = ["/.github", "/docs"]
40
+
41
+ [tool.hatch.build.targets.wheel]
42
+ packages = ["src/hydraflow"]
43
+
44
+ [tool.pytest.ini_options]
45
+ addopts = [
46
+ "--doctest-modules",
47
+ "--cov=hydraflow",
48
+ "--cov-report=lcov:lcov.info",
49
+ ]
50
+
51
+ doctest_optionflags = ["NORMALIZE_WHITESPACE", "IGNORE_EXCEPTION_DETAIL"]
52
+ filterwarnings = ['ignore:pkg_resources is deprecated:DeprecationWarning']
53
+
54
+ [tool.coverage.report]
55
+ exclude_lines = ["no cov", "raise NotImplementedError", "if TYPE_CHECKING:"]
56
+
57
+ [tool.hatch.envs.docs.scripts]
58
+ build = "mkdocs build --clean --strict {args}"
59
+ serve = "mkdocs serve --dev-addr localhost:8000 {args}"
60
+ deploy = "mkdocs gh-deploy --force"
61
+
62
[tool.ruff]
line-length = 88
# Keep in step with `requires-python = ">=3.10"` so Ruff never suggests
# syntax that the oldest supported interpreter cannot parse.
target-version = "py310"
65
+
66
+ [tool.ruff.lint]
67
+ unfixable = ["F401", "RUF100"]
@@ -0,0 +1,28 @@
1
+ from .context import Info, chdir_artifact, log_run, watch
2
+ from .mlflow import set_experiment
3
+ from .run import (
4
+ filter_by_config,
5
+ get_artifact_dir,
6
+ get_artifact_path,
7
+ get_artifact_uri,
8
+ get_by_config,
9
+ get_param_dict,
10
+ get_param_names,
11
+ get_run_id,
12
+ )
13
+
14
+ __all__ = [
15
+ "Info",
16
+ "chdir_artifact",
17
+ "filter_by_config",
18
+ "get_artifact_dir",
19
+ "get_artifact_path",
20
+ "get_artifact_uri",
21
+ "get_by_config",
22
+ "get_param_dict",
23
+ "get_param_names",
24
+ "get_run_id",
25
+ "log_run",
26
+ "set_experiment",
27
+ "watch",
28
+ ]
@@ -0,0 +1,30 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from omegaconf import DictConfig, ListConfig, OmegaConf
6
+
7
+ if TYPE_CHECKING:
8
+ from collections.abc import Iterator
9
+ from typing import Any
10
+
11
+
12
def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
    """Yield ``(dotted_name, value)`` pairs for every leaf of a config.

    Args:
        config: A ``DictConfig``/``ListConfig``, or any object accepted by
            ``OmegaConf.create`` (it is converted first).
        prefix: Text prepended to every generated name; used by the
            recursive calls to build the dotted path.

    Yields:
        The dotted name of a leaf and its value. Mapping keys and list
        indices are both used as path components.
    """
    if not isinstance(config, (DictConfig, ListConfig)):
        config = OmegaConf.create(config)  # type: ignore

    # Mappings contribute their keys, sequences their positions.
    if isinstance(config, DictConfig):
        entries = config.items()
    elif isinstance(config, ListConfig):
        entries = enumerate(config)
    else:
        return

    for name, item in entries:
        path = f"{prefix}{name}"
        if isinstance(item, (DictConfig, ListConfig)):
            yield from iter_params(item, f"{path}.")
        else:
            yield path, item
@@ -0,0 +1,110 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import time
5
+ from contextlib import contextmanager
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING
9
+
10
+ import mlflow
11
+ from hydra.core.hydra_config import HydraConfig
12
+ from watchdog.events import FileModifiedEvent, FileSystemEventHandler
13
+ from watchdog.observers import Observer
14
+
15
+ from hydraflow.mlflow import log_params
16
+ from hydraflow.run import get_artifact_path
17
+ from hydraflow.util import uri_to_path
18
+
19
+ if TYPE_CHECKING:
20
+ from collections.abc import Callable, Iterator
21
+
22
+ from mlflow.entities.run import Run
23
+ from pandas import Series
24
+
25
+
26
@dataclass
class Info:
    """Pair of directories that `log_run` yields to its caller."""

    # Hydra runtime output directory of the current job.
    output_dir: Path
    # Local path derived from the MLflow artifact URI.
    artifact_dir: Path
30
+
31
+
32
@contextmanager
def log_run(
    config: object,
    *,
    synchronous: bool | None = None,
) -> Iterator[Info]:
    """Log a config and the Hydra output directory to MLflow.

    On entry the parameters of ``config`` are logged and the Hydra output
    subdirectory is uploaded as artifacts. On exit, whether or not the body
    raised, the whole Hydra output directory is uploaded.

    Args:
        config: Configuration whose leaves are logged as parameters.
        synchronous: Forwarded to ``log_params``.

    Yields:
        An ``Info`` holding the Hydra output directory and the local path
        derived from the MLflow artifact URI.
    """
    log_params(config, synchronous=synchronous)

    hydra_cfg = HydraConfig.get()
    run_dir = Path(hydra_cfg.runtime.output_dir)
    artifact_dir = uri_to_path(mlflow.get_artifact_uri())
    info = Info(run_dir, artifact_dir)

    # Upload the Hydra config subdirectory up front so it is available
    # even while the body is still running.
    subdir_name = hydra_cfg.output_subdir
    config_dir = run_dir / (subdir_name or "")
    mlflow.log_artifacts(config_dir.as_posix(), subdir_name)

    try:
        yield info
    finally:
        # Upload everything in the output directory, config included.
        mlflow.log_artifacts(run_dir.as_posix())
55
+
56
+
57
@contextmanager
def watch(
    func: Callable[[Path], None],
    dir: Path | str = "",
    timeout: int = 600,
) -> Iterator[None]:
    """Call ``func`` for files modified under ``dir`` while the body runs.

    Args:
        func: Callback receiving the path of each modified file.
        dir: Directory to observe recursively. When empty, the local path
            derived from the current MLflow artifact URI is used.
        timeout: Upper bound, in seconds of accumulated polling sleep, to
            wait on exit for the observer's event queue to drain.
    """
    if not dir:
        dir = uri_to_path(mlflow.get_artifact_uri())

    observer = Observer()
    observer.schedule(Handler(func), dir, recursive=True)
    observer.start()

    try:
        yield
    finally:
        # Give pending events a chance to be dispatched before stopping.
        poll_interval = 0.2
        waited = 0
        while not observer.event_queue.empty():
            time.sleep(poll_interval)
            waited += poll_interval
            if waited > timeout:
                break

        observer.stop()
        observer.join()
85
+
86
+
87
class Handler(FileSystemEventHandler):
    """Event handler that forwards modified files to a callback."""

    def __init__(self, func: Callable[[Path], None]) -> None:
        # Callback invoked with the path of each modified file.
        self.func = func

    def on_modified(self, event: FileModifiedEvent) -> None:
        """Invoke the callback when the event's source path is a file."""
        target = Path(event.src_path)
        if not target.is_file():
            return

        self.func(target)
95
+
96
+
97
@contextmanager
def chdir_artifact(
    run: Run | Series | str,
    artifact_path: str | None = None,
) -> Iterator[Path]:
    """Temporarily change the working directory to a run's artifact path.

    Args:
        run: Run object, row of a runs DataFrame, or run ID.
        artifact_path: Optional path below the run's artifact directory.

    Yields:
        The directory that was made current. The previous working
        directory is restored on exit, even if the body raises.
    """
    previous = Path.cwd()
    target = get_artifact_path(run, artifact_path)

    os.chdir(target)
    try:
        yield target
    finally:
        os.chdir(previous)
@@ -0,0 +1,17 @@
1
+ from __future__ import annotations
2
+
3
+ import mlflow
4
+ from hydra.core.hydra_config import HydraConfig
5
+
6
+ from hydraflow.config import iter_params
7
+
8
+
9
def set_experiment() -> None:
    """Select the MLflow experiment named after the current Hydra job.

    The tracking URI is set to an empty string before the experiment is
    selected.
    """
    hydra_cfg = HydraConfig.get()
    mlflow.set_tracking_uri("")
    experiment_name = hydra_cfg.job.name
    mlflow.set_experiment(experiment_name)
13
+
14
+
15
def log_params(config: object, *, synchronous: bool | None = None) -> None:
    """Log every leaf of ``config`` as an MLflow parameter.

    Args:
        config: Configuration flattened with ``iter_params``.
        synchronous: Forwarded unchanged to ``mlflow.log_param``.
    """
    flattened = iter_params(config)
    for name, item in flattened:
        mlflow.log_param(name, item, synchronous=synchronous)
@@ -0,0 +1,172 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING, Any, overload
5
+
6
+ import mlflow
7
+ import numpy as np
8
+ from mlflow.entities.run import Run
9
+ from mlflow.tracking import artifact_utils
10
+ from omegaconf import DictConfig, OmegaConf
11
+
12
+ from hydraflow.config import iter_params
13
+ from hydraflow.util import uri_to_path
14
+
15
+ if TYPE_CHECKING:
16
+ from typing import Any
17
+
18
+ from pandas import DataFrame, Series
19
+
20
+
21
@overload
def filter_by_config(runs: list[Run], config: object) -> list[Run]: ...


@overload
def filter_by_config(runs: DataFrame, config: object) -> DataFrame: ...


def filter_by_config(runs: list[Run] | DataFrame, config: object):
    """Keep only the runs whose parameters match ``config``.

    Dispatches on the container type: a list goes to
    ``filter_by_config_list``, anything else to
    ``filter_by_config_dataframe``. The result has the same container
    type as the input.
    """
    if isinstance(runs, list):
        selector = filter_by_config_list
    else:
        selector = filter_by_config_dataframe

    return selector(runs, config)
34
+
35
+
36
+ def _is_equal(run: Run, key: str, value: Any) -> bool:
37
+ param = run.data.params.get(key, value)
38
+ if param is None:
39
+ return False
40
+
41
+ return type(value)(param) == value
42
+
43
+
44
def filter_by_config_list(runs: list[Run], config: object) -> list[Run]:
    """Return the runs in a list whose parameters match ``config``.

    Every leaf of ``config`` (as produced by ``iter_params``) must be
    matched by a run, according to ``_is_equal``, for it to be kept.
    """
    criteria = list(iter_params(config))
    if not criteria:
        # Nothing to filter on: hand back the input unchanged.
        return runs

    return [
        run
        for run in runs
        if all(_is_equal(run, key, value) for key, value in criteria)
    ]
49
+
50
+
51
def filter_by_config_dataframe(runs: DataFrame, config: object) -> DataFrame:
    """Return the rows of a runs DataFrame whose parameters match ``config``.

    Each leaf of ``config`` is compared with the column ``params.<name>``.
    Leaves without a corresponding column are ignored. Rows with a missing
    value in a compared column never match.

    Args:
        runs: DataFrame with one row per run and ``params.*`` columns.
        config: Configuration flattened with ``iter_params``.

    Returns:
        The subset of ``runs`` matching every compared parameter.
    """
    mask = np.ones(len(runs), dtype=bool)

    for key, value in iter_params(config):
        name = f"params.{key}"
        if name not in runs:
            continue

        series = runs[name]
        present = series.notna()

        if isinstance(value, bool):
            # astype(bool) would turn the string "False" into True, so
            # booleans are compared through their string form.
            matched = series.astype(str) == str(value)
        else:
            # Missing cells are filled only so the cast succeeds; they are
            # excluded again through ``present``.
            matched = series.fillna(value).astype(type(value)) == value

        mask &= (present & matched).to_numpy()

    return runs[mask]
63
+
64
+
65
@overload
def get_by_config(runs: list[Run], config: object) -> Run: ...


@overload
def get_by_config(runs: DataFrame, config: object) -> Series: ...


def get_by_config(runs: list[Run] | DataFrame, config: object):
    """Return the single run whose parameters match ``config``.

    Returns:
        The matching element of a list, or the matching row of a DataFrame.

    Raises:
        ValueError: If the number of matching runs is not exactly one.
    """
    matched = filter_by_config(runs, config)

    if len(matched) != 1:
        msg = f"filtered runs has not length of 1.: {len(matched)}"
        raise ValueError(msg)

    if isinstance(matched, list):
        return matched[0]

    return matched.iloc[0]
81
+
82
+
83
def drop_unique_params(runs: DataFrame) -> DataFrame:
    """Drop the ``params.*`` columns that hold a single distinct value.

    Columns whose name does not start with ``params.`` are always kept.
    """
    keep = []
    for column in runs.columns:
        is_param = column.startswith("params.")
        # A parameter column is only informative if it varies across runs.
        keep.append(not is_param or len(runs[column].unique()) > 1)

    return runs.iloc[:, keep]
89
+
90
+
91
def get_param_names(runs: DataFrame) -> list[str]:
    """List the parameter names found in a runs DataFrame.

    A parameter name is the part of a column label that follows the
    ``params.`` prefix. Columns without that prefix, or with nothing after
    it, are skipped. Column order is preserved.
    """
    prefix = "params."
    stripped = (
        column[len(prefix) :] for column in runs.columns if column.startswith(prefix)
    )
    return [name for name in stripped if name]
100
+
101
+
102
def get_param_dict(runs: DataFrame) -> dict[str, list[str]]:
    """Map each parameter name to the distinct values found in its column.

    Names come from ``get_param_names``; values are the unique entries of
    the matching ``params.<name>`` column.
    """
    return {
        name: list(runs[f"params.{name}"].unique()) for name in get_param_names(runs)
    }
108
+
109
+
110
def get_run_id(run: Run | Series | str) -> str:
    """Return the run ID for a run given in any supported form.

    A string is taken to be the ID itself, a ``Run`` provides it through
    ``run.info.run_id``, and anything else is read via its ``run_id``
    attribute.
    """
    if isinstance(run, str):
        return run

    if isinstance(run, Run):
        return run.info.run_id

    return run.run_id
116
+
117
+
118
def get_artifact_uri(run: Run | Series | str, artifact_path: str | None = None) -> str:
    """Return the artifact URI of a run, optionally extended by a path.

    Args:
        run: Run object, row of a runs DataFrame, or run ID.
        artifact_path: If given and non-empty, appended to the URI after
            a ``/`` separator.
    """
    if isinstance(run, str):
        base = artifact_utils.get_artifact_uri(run_id=run)
    elif isinstance(run, Run):
        base = run.info.artifact_uri
    else:
        base = run.artifact_uri

    if not artifact_path:
        return base  # type: ignore

    return f"{base}/{artifact_path}"
130
+
131
+
132
def get_artifact_dir(run: Run | Series | str) -> Path:
    """Return the run's artifact URI converted to a filesystem path."""
    return uri_to_path(get_artifact_uri(run))
135
+
136
+
137
def get_artifact_path(
    run: Run | Series | str,
    artifact_path: str | None = None,
) -> Path:
    """Return a path inside the run's artifact directory.

    Without ``artifact_path`` (or with an empty one) the artifact
    directory itself is returned.
    """
    base = get_artifact_dir(run)
    if not artifact_path:
        return base

    return base / artifact_path
143
+
144
+
145
def load_config(run: Run | Series | str, output_subdir: str = ".hydra") -> DictConfig:
    """Load ``config.yaml`` from a run's artifacts.

    Args:
        run: Run object, row of a runs DataFrame, or run ID.
        output_subdir: Artifact subdirectory that holds ``config.yaml``.

    Returns:
        The loaded configuration, or an empty ``DictConfig`` when the
        download raises ``OSError``.
    """
    run_id = get_run_id(run)
    config_file = f"{output_subdir}/config.yaml"

    try:
        local_path = mlflow.artifacts.download_artifacts(
            run_id=run_id,
            artifact_path=config_file,
        )
    except OSError:
        return DictConfig({})

    return OmegaConf.load(local_path)  # type: ignore
157
+
158
+
159
def get_hydra_output_dir(run: Run | Series | str) -> Path:
    """Return the Hydra output directory recorded in a run's artifacts.

    The value is read from ``hydra.runtime.output_dir`` in the file
    ``.hydra/hydra.yaml`` under the run's artifact directory.

    Raises:
        FileNotFoundError: If that file does not exist. The message names
            the path that was looked up.
    """
    path = get_artifact_dir(run) / ".hydra/hydra.yaml"

    if not path.exists():
        msg = f"Hydra config file not found: {path}"
        raise FileNotFoundError(msg)

    hc = OmegaConf.load(path)
    return Path(hc.hydra.runtime.output_dir)
167
+
168
+
169
def log_hydra_output_dir(run: Run | Series | str) -> None:
    """Upload a run's Hydra output directory as artifacts of that run.

    Args:
        run: Run object, row of a runs DataFrame, or run ID.

    Raises:
        FileNotFoundError: Propagated from ``get_hydra_output_dir`` when
            the run has no Hydra config among its artifacts.
    """
    output_dir = get_hydra_output_dir(run)
    # get_run_id handles all three accepted forms; a DataFrame row has no
    # ``info`` attribute, so ``run.info.run_id`` would fail for it.
    run_id = get_run_id(run)
    mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
@@ -0,0 +1,11 @@
1
+ import platform
2
+ from pathlib import Path
3
+ from urllib.parse import urlparse
4
+
5
+
6
def uri_to_path(uri: str) -> Path:
    """Convert a URI (or plain path string) to a filesystem path.

    Args:
        uri: A URI such as ``file:///tmp/my%20dir`` or a plain path.

    Returns:
        The path component of ``uri``. For ``file:`` URIs percent-encoded
        characters are decoded, so ``%20`` becomes a space. On Windows a
        single leading slash is removed from the path component.
    """
    from urllib.parse import unquote

    parsed = urlparse(uri)
    path = parsed.path

    # Only real file URIs are percent-encoded; a plain path may contain a
    # literal "%" that must be left alone.
    if parsed.scheme == "file":
        path = unquote(path)

    if platform.system() == "Windows" and path.startswith("/"):
        path = path[1:]

    return Path(path)
@@ -0,0 +1,37 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from dataclasses import dataclass
5
+
6
+ import hydra
7
+ import mlflow
8
+ from hydra.core.config_store import ConfigStore
9
+
10
+ from hydraflow.context import log_run
11
+
12
+ log = logging.getLogger(__name__)
13
+
14
+
15
+ @dataclass
16
+ class MySQLConfig:
17
+ host: str = "localhost"
18
+ port: int = 3306
19
+
20
+
21
+ cs = ConfigStore.instance()
22
+ cs.store(name="config", node=MySQLConfig)
23
+
24
+
25
+ @hydra.main(version_base=None, config_name="config")
26
+ def app(cfg: MySQLConfig):
27
+ mlflow.set_experiment("log_run")
28
+ with mlflow.start_run(), log_run(cfg) as info:
29
+ log.info(f"START, {cfg.host}, {cfg.port} ")
30
+ mlflow.log_text(info.artifact_dir.as_posix(), "artifact_dir.txt")
31
+ mlflow.log_text(info.output_dir.as_posix(), "output_dir.txt")
32
+ (info.artifact_dir / "a.txt").write_text("abc")
33
+ log.info("END")
34
+
35
+
36
+ if __name__ == "__main__":
37
+ app()
@@ -0,0 +1,63 @@
1
+ from dataclasses import dataclass, field
2
+
3
+ import pytest
4
+ from omegaconf import OmegaConf
5
+
6
+
7
+ def test_iter_params():
8
+ from hydraflow.config import iter_params
9
+
10
+ conf = OmegaConf.create({"k": "v", "l": [1, {"a": "1", "b": "2", 3: "c"}]})
11
+ it = iter_params(conf)
12
+ assert next(it) == ("k", "v")
13
+ assert next(it) == ("l.0", 1)
14
+ assert next(it) == ("l.1.a", "1")
15
+ assert next(it) == ("l.1.b", "2")
16
+ assert next(it) == ("l.1.3", "c")
17
+
18
+
19
+ @dataclass
20
+ class Size:
21
+ x: int = 1
22
+ y: int = 2
23
+
24
+
25
+ @dataclass
26
+ class Db:
27
+ name: str = "name"
28
+ port: int = 100
29
+
30
+
31
+ @dataclass
32
+ class Store:
33
+ items: list[str] = field(default_factory=lambda: ["a", "b"])
34
+
35
+
36
+ @dataclass
37
+ class Config:
38
+ size: Size = field(default_factory=Size)
39
+ db: Db = field(default_factory=Db)
40
+ store: Store = field(default_factory=Store)
41
+
42
+
43
+ @pytest.fixture
44
+ def cfg():
45
+ return Config()
46
+
47
+
48
+ def test_config(cfg: Config):
49
+ assert cfg.size.x == 1
50
+ assert cfg.db.name == "name"
51
+ assert cfg.store.items == ["a", "b"]
52
+
53
+
54
+ def test_iter_params_from_config(cfg):
55
+ from hydraflow.config import iter_params
56
+
57
+ it = iter_params(cfg)
58
+ assert next(it) == ("size.x", 1)
59
+ assert next(it) == ("size.y", 2)
60
+ assert next(it) == ("db.name", "name")
61
+ assert next(it) == ("db.port", 100)
62
+ assert next(it) == ("store.items.0", "a")
63
+ assert next(it) == ("store.items.1", "b")
@@ -0,0 +1,64 @@
1
+ from __future__ import annotations
2
+
3
+ import subprocess
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import mlflow
8
+ import pytest
9
+ from mlflow.artifacts import download_artifacts
10
+ from mlflow.entities.run import Run
11
+
12
+
13
+ @pytest.fixture
14
+ def runs(monkeypatch, tmp_path):
15
+ file = Path("tests/log_run.py").absolute()
16
+ monkeypatch.chdir(tmp_path)
17
+
18
+ subprocess.check_call(
19
+ [sys.executable, file.as_posix(), "-m", "host=x,y", "port=1,2"]
20
+ )
21
+
22
+ mlflow.set_experiment("log_run")
23
+ runs = mlflow.search_runs(output_format="list")
24
+ assert len(runs) == 4
25
+ assert isinstance(runs, list)
26
+ yield runs
27
+
28
+
29
+ @pytest.fixture(params=range(4))
30
+ def run_id(runs, request):
31
+ run = runs[request.param]
32
+ assert isinstance(run, Run)
33
+ return run.info.run_id
34
+
35
+
36
+ def test_output(run_id: str):
37
+ path = download_artifacts(run_id=run_id, artifact_path="a.txt")
38
+ text = Path(path).read_text()
39
+ assert text == "abc"
40
+
41
+
42
+ def read_log(run_id: str) -> str:
43
+ path = download_artifacts(run_id=run_id, artifact_path="log_run.log")
44
+ text = Path(path).read_text()
45
+ assert "START" in text
46
+ assert "END" in text
47
+ return text
48
+
49
+
50
+ def test_load_config(run_id: str):
51
+ from hydraflow.run import load_config
52
+
53
+ log = read_log(run_id)
54
+ host, port = log.splitlines()[0].split("START,")[-1].split(",")
55
+
56
+ cfg = load_config(run_id)
57
+ assert cfg.host == host.strip()
58
+ assert cfg.port == int(port)
59
+
60
+
61
+ def test_load_config_err(run_id: str):
62
+ from hydraflow.run import load_config
63
+
64
+ assert not load_config(run_id, "a")
@@ -0,0 +1,188 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+
5
+ import mlflow
6
+ import pytest
7
+ from mlflow.entities import Run
8
+ from pandas import DataFrame, Series
9
+
10
+
11
+ @pytest.fixture
12
+ def _runs(monkeypatch, tmp_path):
13
+ monkeypatch.chdir(tmp_path)
14
+
15
+ mlflow.set_experiment("test_run")
16
+ for x in range(6):
17
+ with mlflow.start_run(run_name=f"{x}"):
18
+ mlflow.log_param("p", x)
19
+ mlflow.log_param("q", 0)
20
+ mlflow.log_text(f"{x}", "abc.txt")
21
+
22
+ x = mlflow.search_runs(output_format="list", order_by=["params.p"])
23
+ assert isinstance(x, list)
24
+ assert isinstance(x[0], Run)
25
+ y = mlflow.search_runs(output_format="pandas", order_by=["params.p"])
26
+ assert isinstance(y, DataFrame)
27
+ return x, y
28
+
29
+
30
+ @pytest.fixture(params=["list", "pandas"])
31
+ def runs(_runs: tuple[list[Run], DataFrame], request: pytest.FixtureRequest):
32
+ if request.param == "list":
33
+ return _runs[0]
34
+
35
+ return _runs[1]
36
+
37
+
38
+ @pytest.fixture
39
+ def runs_list(_runs: tuple[list[Run], DataFrame], request: pytest.FixtureRequest):
40
+ return _runs[0]
41
+
42
+
43
+ @pytest.fixture
44
+ def runs_df(_runs: tuple[list[Run], DataFrame], request: pytest.FixtureRequest):
45
+ return _runs[1]
46
+
47
+
48
+ def test_filter_by_config_one(runs: list[Run] | DataFrame):
49
+ from hydraflow.run import filter_by_config
50
+
51
+ assert len(runs) == 6
52
+ x = filter_by_config(runs, {"p": 1})
53
+ assert len(x) == 1
54
+
55
+
56
+ def test_filter_by_config_all(runs: list[Run] | DataFrame):
57
+ from hydraflow.run import filter_by_config
58
+
59
+ assert len(runs) == 6
60
+ x = filter_by_config(runs, {"q": 0})
61
+ assert len(x) == 6
62
+
63
+
64
+ def test_get_by_config_list(runs_list: list[Run]):
65
+ from hydraflow.run import get_by_config
66
+
67
+ run = get_by_config(runs_list, {"p": 4})
68
+ assert isinstance(run, Run)
69
+ assert run.data.params["p"] == "4"
70
+
71
+
72
+ def test_get_by_config_df(runs_df: DataFrame):
73
+ from hydraflow.run import get_by_config
74
+
75
+ run = get_by_config(runs_df, {"p": 2})
76
+ assert isinstance(run, Series)
77
+ assert run["params.p"] == "2"
78
+
79
+
80
+ def test_get_by_config_error(runs: list[Run] | DataFrame):
81
+ from hydraflow.run import get_by_config
82
+
83
+ with pytest.raises(ValueError):
84
+ get_by_config(runs, {"q": 0})
85
+
86
+
87
+ def test_drop_unique_params(runs_df):
88
+ from hydraflow.run import drop_unique_params
89
+
90
+ assert "params.p" in runs_df
91
+ assert "params.q" in runs_df
92
+ df = drop_unique_params(runs_df)
93
+ assert "params.p" in df
94
+ assert "params.q" not in df
95
+
96
+
97
+ def test_get_param_names(runs_df: DataFrame):
98
+ from hydraflow.run import get_param_names
99
+
100
+ params = get_param_names(runs_df)
101
+ assert len(params) == 2
102
+ assert "p" in params
103
+ assert "q" in params
104
+
105
+
106
+ def test_get_param_dict(runs_df: DataFrame):
107
+ from hydraflow.run import get_param_dict
108
+
109
+ params = get_param_dict(runs_df)
110
+ assert len(params["p"]) == 6
111
+ assert len(params["q"]) == 1
112
+
113
+
114
+ @pytest.mark.parametrize("i", range(6))
115
+ def test_get_run_id(i: int, runs_list: list[Run], runs_df: DataFrame):
116
+ from hydraflow.run import get_run_id
117
+
118
+ assert get_run_id(runs_list[i]) == get_run_id(runs_df.iloc[i])
119
+ assert get_run_id(runs_list[i]) == get_run_id(runs_df.iloc[i])
120
+
121
+ x = get_run_id(runs_list[i])
122
+ assert get_run_id(x) == runs_list[i].info.run_id
123
+
124
+
125
+ @pytest.mark.parametrize("i", range(6))
126
+ @pytest.mark.parametrize("path", [None, "a"])
127
+ def test_get_artifact_uri(i: int, path, runs_list: list[Run], runs_df: DataFrame):
128
+ from hydraflow.run import get_artifact_uri, get_run_id
129
+
130
+ x = get_run_id(runs_list[i])
131
+ y = get_artifact_uri(runs_list[i], path)
132
+ assert get_artifact_uri(x, path) == y
133
+ assert get_artifact_uri(runs_df.iloc[i], path) == y
134
+
135
+
136
+ @pytest.mark.parametrize("i", range(6))
137
+ def test_chdir_artifact_list(i: int, runs_list: list[Run]):
138
+ from hydraflow.context import chdir_artifact
139
+
140
+ with chdir_artifact(runs_list[i]):
141
+ assert Path("abc.txt").read_text() == f"{i}"
142
+
143
+ assert not Path("abc.txt").exists()
144
+
145
+
146
+ def test_hydra_output_dir_error(runs_list: list[Run]):
147
+ from hydraflow.run import get_hydra_output_dir
148
+
149
+ with pytest.raises(FileNotFoundError):
150
+ get_hydra_output_dir(runs_list[0])
151
+
152
+
153
+ @pytest.fixture
154
+ def df():
155
+ return DataFrame(
156
+ {
157
+ "a": [0, 0, 0, 0],
158
+ "params.x": [1, 1, 2, 2],
159
+ "params.y": [1, 2, 1, 2],
160
+ "params.z": [1, 1, 1, 1],
161
+ },
162
+ )
163
+
164
+
165
+ def test_unique_params(df):
166
+ from hydraflow.run import drop_unique_params
167
+
168
+ df = drop_unique_params(df)
169
+ assert len(df.columns) == 3
170
+ assert "a" in df
171
+ assert "params.x" in df
172
+ assert "params.z" not in df
173
+
174
+
175
+ def test_param_names(df):
176
+ from hydraflow.run import get_param_names
177
+
178
+ names = get_param_names(df)
179
+ assert names == ["x", "y", "z"]
180
+
181
+
182
+ def test_param_dict(df):
183
+ from hydraflow.run import get_param_dict
184
+
185
+ x = get_param_dict(df)
186
+ assert x["x"] == [1, 2]
187
+ assert x["y"] == [1, 2]
188
+ assert x["z"] == [1]
@@ -0,0 +1,5 @@
1
+ from importlib.metadata import version
2
+
3
+
4
+ def test_version():
5
+ assert version("hydraflow").count(".") == 2
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+
3
+ import subprocess
4
+ import time
5
+ from pathlib import Path
6
+
7
+ import pytest
8
+
9
+
10
@pytest.mark.parametrize("dir", [".", Path])
def test_watch(dir, monkeypatch, tmp_path):
    """Check that ``watch`` reports each file written by a child process.

    ``dir`` is either the string ``"."`` or the ``Path`` class; the latter
    is instantiated after the working directory has been changed.
    """
    import sys

    from hydraflow.context import watch

    file = Path("tests/watch.py").absolute()
    monkeypatch.chdir(tmp_path)

    lines = []

    def func(path: Path) -> None:
        # Each watched file holds "<index> <write timestamp>".
        k, t = path.read_text().split(" ")
        lines.append([int(k), float(t), time.time(), path.name])

    with watch(func, dir if isinstance(dir, str) else dir()):
        # Use the interpreter running the tests rather than whatever
        # "python" happens to resolve to on PATH.
        subprocess.check_call([sys.executable, file])

    for k in range(4):
        assert lines[k][0] == k
        assert lines[k][-1] == f"{k}.txt"
        assert 0 <= lines[k][2] - lines[k][1] < 0.05
@@ -0,0 +1,19 @@
1
+ import sys
2
+ import time
3
+ from pathlib import Path
4
+
5
+
6
def run():
    """Write four numbered files one second apart, echoing progress.

    For each index ``k`` from 0 to 3 a line is printed (to stdout for the
    first two, to stderr for the last two), then ``<k>.txt`` is written
    with the index and the current time.
    """
    for k in range(4):
        stream = sys.stdout if k < 2 else sys.stderr
        print(k, time.time(), file=stream, flush=True)
        Path(f"{k}.txt").write_text(f"{k} {time.time()}")
        time.sleep(1)
16
+
17
+
18
+ if __name__ == "__main__":
19
+ run()