hydraflow 0.1.0__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- hydraflow-0.1.0/.devcontainer/devcontainer.json +16 -0
- hydraflow-0.1.0/.devcontainer/postCreate.sh +5 -0
- hydraflow-0.1.0/.devcontainer/starship.toml +29 -0
- hydraflow-0.1.0/.gitignore +4 -0
- hydraflow-0.1.0/LICENSE +21 -0
- hydraflow-0.1.0/PKG-INFO +29 -0
- hydraflow-0.1.0/README.md +1 -0
- hydraflow-0.1.0/pyproject.toml +67 -0
- hydraflow-0.1.0/src/hydraflow/__init__.py +28 -0
- hydraflow-0.1.0/src/hydraflow/config.py +30 -0
- hydraflow-0.1.0/src/hydraflow/context.py +110 -0
- hydraflow-0.1.0/src/hydraflow/mlflow.py +17 -0
- hydraflow-0.1.0/src/hydraflow/run.py +172 -0
- hydraflow-0.1.0/src/hydraflow/util.py +11 -0
- hydraflow-0.1.0/tests/log_run.py +37 -0
- hydraflow-0.1.0/tests/test_config.py +63 -0
- hydraflow-0.1.0/tests/test_log_run.py +64 -0
- hydraflow-0.1.0/tests/test_run.py +188 -0
- hydraflow-0.1.0/tests/test_version.py +5 -0
- hydraflow-0.1.0/tests/test_watch.py +29 -0
- hydraflow-0.1.0/tests/watch.py +19 -0
@@ -0,0 +1,16 @@
|
|
1
|
+
{
|
2
|
+
"name": "hydraflow",
|
3
|
+
"image": "mcr.microsoft.com/devcontainers/python:3.10-bookworm",
|
4
|
+
"features": {
|
5
|
+
"ghcr.io/devcontainers-contrib/features/ruff:1": {},
|
6
|
+
"ghcr.io/devcontainers-contrib/features/hatch:2": {},
|
7
|
+
"ghcr.io/devcontainers-contrib/features/starship:1": {},
|
8
|
+
"ghcr.io/va-h/devcontainers-features/uv:1": {}
|
9
|
+
},
|
10
|
+
"customizations": {
|
11
|
+
"vscode": {
|
12
|
+
"extensions": ["charliermarsh.ruff"]
|
13
|
+
}
|
14
|
+
},
|
15
|
+
"postCreateCommand": ".devcontainer/postCreate.sh"
|
16
|
+
}
|
@@ -0,0 +1,29 @@
|
|
1
|
+
"$schema" = 'https://starship.rs/config-schema.json'
|
2
|
+
|
3
|
+
add_newline = false
|
4
|
+
|
5
|
+
[username]
|
6
|
+
disabled = true
|
7
|
+
|
8
|
+
[hostname]
|
9
|
+
disabled = true
|
10
|
+
|
11
|
+
[package]
|
12
|
+
format = '[$symbol$version]($style) '
|
13
|
+
symbol = " "
|
14
|
+
style = 'bold blue'
|
15
|
+
|
16
|
+
[git_branch]
|
17
|
+
format = '[$symbol$branch(:$remote_branch)]($style)'
|
18
|
+
symbol = ' '
|
19
|
+
style = 'bold green'
|
20
|
+
|
21
|
+
[git_status]
|
22
|
+
style = 'red'
|
23
|
+
modified = '*'
|
24
|
+
format = '([$modified]($style)) '
|
25
|
+
|
26
|
+
[python]
|
27
|
+
format = '[${symbol}${pyenv_prefix}(${version} )(\($virtualenv\) )]($style)'
|
28
|
+
symbol = ' '
|
29
|
+
style = 'yellow'
|
hydraflow-0.1.0/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2024 Daizu
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|
hydraflow-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,29 @@
|
|
1
|
+
Metadata-Version: 2.3
|
2
|
+
Name: hydraflow
|
3
|
+
Version: 0.1.0
|
4
|
+
Summary: Hydra with MLflow
|
5
|
+
Project-URL: Source, https://github.com/daizutabi/hydraflow
|
6
|
+
Project-URL: Issues, https://github.com/daizutabi/hydraflow/issues
|
7
|
+
Author-email: daizutabi <daizutabi@gmail.com>
|
8
|
+
License-Expression: MIT
|
9
|
+
License-File: LICENSE
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
11
|
+
Classifier: Programming Language :: Python
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
15
|
+
Classifier: Topic :: Documentation
|
16
|
+
Classifier: Topic :: Software Development :: Documentation
|
17
|
+
Requires-Python: >=3.10
|
18
|
+
Requires-Dist: hydra-core
|
19
|
+
Requires-Dist: mlflow
|
20
|
+
Requires-Dist: watchdog
|
21
|
+
Provides-Extra: dev
|
22
|
+
Requires-Dist: pytest-clarity; extra == 'dev'
|
23
|
+
Requires-Dist: pytest-cov; extra == 'dev'
|
24
|
+
Requires-Dist: pytest-randomly; extra == 'dev'
|
25
|
+
Requires-Dist: pytest-xdist; extra == 'dev'
|
26
|
+
Requires-Dist: setuptools; extra == 'dev'
|
27
|
+
Description-Content-Type: text/markdown
|
28
|
+
|
29
|
+
# hydraflow
|
@@ -0,0 +1 @@
|
|
1
|
+
# hydraflow
|
@@ -0,0 +1,67 @@
|
|
1
|
+
[build-system]
|
2
|
+
requires = ["hatchling"]
|
3
|
+
build-backend = "hatchling.build"
|
4
|
+
|
5
|
+
[project]
|
6
|
+
name = "hydraflow"
|
7
|
+
version = "0.1.0"
|
8
|
+
description = "Hydra with MLflow"
|
9
|
+
readme = "README.md"
|
10
|
+
license = "MIT"
|
11
|
+
authors = [{ name = "daizutabi", email = "daizutabi@gmail.com" }]
|
12
|
+
classifiers = [
|
13
|
+
"Development Status :: 4 - Beta",
|
14
|
+
"Programming Language :: Python",
|
15
|
+
"Programming Language :: Python :: 3.10",
|
16
|
+
"Programming Language :: Python :: 3.11",
|
17
|
+
"Programming Language :: Python :: 3.12",
|
18
|
+
"Topic :: Documentation",
|
19
|
+
"Topic :: Software Development :: Documentation",
|
20
|
+
]
|
21
|
+
requires-python = ">=3.10"
|
22
|
+
dependencies = ["hydra-core", "mlflow", "watchdog"]
|
23
|
+
|
24
|
+
[project.optional-dependencies]
|
25
|
+
dev = [
|
26
|
+
"pytest-clarity",
|
27
|
+
"pytest-cov",
|
28
|
+
"pytest-randomly",
|
29
|
+
"pytest-xdist",
|
30
|
+
"setuptools",
|
31
|
+
]
|
32
|
+
|
33
|
+
[project.urls]
|
34
|
+
# Documentation = "https://daizutabi.github.io/hydraflow/"
|
35
|
+
Source = "https://github.com/daizutabi/hydraflow"
|
36
|
+
Issues = "https://github.com/daizutabi/hydraflow/issues"
|
37
|
+
|
38
|
+
[tool.hatch.build.targets.sdist]
|
39
|
+
exclude = ["/.github", "/docs"]
|
40
|
+
|
41
|
+
[tool.hatch.build.targets.wheel]
|
42
|
+
packages = ["src/hydraflow"]
|
43
|
+
|
44
|
+
[tool.pytest.ini_options]
|
45
|
+
addopts = [
|
46
|
+
"--doctest-modules",
|
47
|
+
"--cov=hydraflow",
|
48
|
+
"--cov-report=lcov:lcov.info",
|
49
|
+
]
|
50
|
+
|
51
|
+
doctest_optionflags = ["NORMALIZE_WHITESPACE", "IGNORE_EXCEPTION_DETAIL"]
|
52
|
+
filterwarnings = ['ignore:pkg_resources is deprecated:DeprecationWarning']
|
53
|
+
|
54
|
+
[tool.coverage.report]
|
55
|
+
exclude_lines = ["no cov", "raise NotImplementedError", "if TYPE_CHECKING:"]
|
56
|
+
|
57
|
+
[tool.hatch.envs.docs.scripts]
|
58
|
+
build = "mkdocs build --clean --strict {args}"
|
59
|
+
serve = "mkdocs serve --dev-addr localhost:8000 {args}"
|
60
|
+
deploy = "mkdocs gh-deploy --force"
|
61
|
+
|
62
|
+
[tool.ruff]
|
63
|
+
line-length = 88
|
64
|
+
target-version = "py312"
|
65
|
+
|
66
|
+
[tool.ruff.lint]
|
67
|
+
unfixable = ["F401", "RUF100"]
|
@@ -0,0 +1,28 @@
|
|
1
|
+
from .context import Info, chdir_artifact, log_run, watch
|
2
|
+
from .mlflow import set_experiment
|
3
|
+
from .run import (
|
4
|
+
filter_by_config,
|
5
|
+
get_artifact_dir,
|
6
|
+
get_artifact_path,
|
7
|
+
get_artifact_uri,
|
8
|
+
get_by_config,
|
9
|
+
get_param_dict,
|
10
|
+
get_param_names,
|
11
|
+
get_run_id,
|
12
|
+
)
|
13
|
+
|
14
|
+
# Public API of the ``hydraflow`` package: names re-exported from the
# ``context``, ``mlflow`` and ``run`` submodules imported above.
__all__ = [
    "Info",
    "chdir_artifact",
    "filter_by_config",
    "get_artifact_dir",
    "get_artifact_path",
    "get_artifact_uri",
    "get_by_config",
    "get_param_dict",
    "get_param_names",
    "get_run_id",
    "log_run",
    "set_experiment",
    "watch",
]
|
@@ -0,0 +1,30 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from collections.abc import Iterator
|
9
|
+
from typing import Any
|
10
|
+
|
11
|
+
|
12
|
+
def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
    """Yield flattened ``(dotted_key, value)`` pairs from a nested config.

    Inputs that are not already OmegaConf containers (plain dicts, lists,
    dataclass instances, ...) are converted with ``OmegaConf.create`` first.
    Mapping keys and sequence indices are joined with ``"."``, so
    ``{"l": [1, {"a": 2}]}`` yields ``("l.0", 1)`` and ``("l.1.a", 2)``.

    Args:
        config: A ``DictConfig``/``ListConfig`` or anything accepted by
            ``OmegaConf.create``.
        prefix: Text prepended to every yielded key (used when recursing).

    Yields:
        Tuples of the flattened key and the corresponding leaf value.
    """
    if not isinstance(config, (DictConfig, ListConfig)):
        config = OmegaConf.create(config)  # type: ignore

    if isinstance(config, DictConfig):
        entries = config.items()
    elif isinstance(config, ListConfig):
        entries = enumerate(config)
    else:
        return

    for name, child in entries:
        key = f"{prefix}{name}"
        if isinstance(child, (DictConfig, ListConfig)):
            # Containers are expanded recursively; only leaves are yielded.
            yield from iter_params(child, f"{key}.")
        else:
            yield key, child
|
@@ -0,0 +1,110 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import os
|
4
|
+
import time
|
5
|
+
from contextlib import contextmanager
|
6
|
+
from dataclasses import dataclass
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import TYPE_CHECKING
|
9
|
+
|
10
|
+
import mlflow
|
11
|
+
from hydra.core.hydra_config import HydraConfig
|
12
|
+
from watchdog.events import FileModifiedEvent, FileSystemEventHandler
|
13
|
+
from watchdog.observers import Observer
|
14
|
+
|
15
|
+
from hydraflow.mlflow import log_params
|
16
|
+
from hydraflow.run import get_artifact_path
|
17
|
+
from hydraflow.util import uri_to_path
|
18
|
+
|
19
|
+
if TYPE_CHECKING:
|
20
|
+
from collections.abc import Callable, Iterator
|
21
|
+
|
22
|
+
from mlflow.entities.run import Run
|
23
|
+
from pandas import Series
|
24
|
+
|
25
|
+
|
26
|
+
@dataclass
class Info:
    """Directories associated with a run started through ``log_run``."""

    # Hydra's runtime output directory of the current job.
    output_dir: Path
    # Local filesystem path converted from the active MLflow artifact URI.
    artifact_dir: Path
|
30
|
+
|
31
|
+
|
32
|
+
@contextmanager
def log_run(
    config: object,
    *,
    synchronous: bool | None = None,
) -> Iterator[Info]:
    """Log a Hydra job's parameters and outputs to the active MLflow run.

    On entry every flattened parameter of ``config`` is logged, then Hydra's
    output subdirectory is uploaded as artifacts. On exit -- also when the
    body raises -- the complete Hydra output directory is uploaded to the
    run's artifact root.

    Args:
        config: The job config, logged through ``log_params``.
        synchronous: Forwarded to ``mlflow.log_param``.

    Yields:
        An ``Info`` with the Hydra output directory and the local artifact
        directory of the active run.
    """
    log_params(config, synchronous=synchronous)

    hydra_cfg = HydraConfig.get()
    output_dir = Path(hydra_cfg.runtime.output_dir)
    info = Info(output_dir, uri_to_path(mlflow.get_artifact_uri()))

    # Upload the config subdirectory before the body runs.
    subdir = hydra_cfg.output_subdir
    mlflow.log_artifacts((output_dir / (subdir or "")).as_posix(), subdir)

    try:
        yield info
    finally:
        # Upload everything in the output directory, config subdirectory included.
        mlflow.log_artifacts(output_dir.as_posix())
|
55
|
+
|
56
|
+
|
57
|
+
@contextmanager
def watch(
    func: Callable[[Path], None],
    dir: Path | str = "",
    timeout: int = 600,
) -> Iterator[None]:
    """Call ``func`` for every file modified under ``dir`` while the block runs.

    Args:
        func: Callback receiving the path of each modified file.
        dir: Directory watched recursively. When empty, the local path of the
            active MLflow run's artifact URI is used.
        timeout: Upper bound, in seconds, on how long to wait on exit for the
            observer's event queue to drain before stopping it.
    """
    target = dir or uri_to_path(mlflow.get_artifact_uri())

    observer = Observer()
    observer.schedule(Handler(func), target, recursive=True)
    observer.start()

    try:
        yield
    finally:
        # Poll until queued events are consumed or the timeout is exceeded.
        waited = 0
        while not observer.event_queue.empty():
            time.sleep(0.2)
            waited += 0.2
            if waited > timeout:
                break

        observer.stop()
        observer.join()
|
85
|
+
|
86
|
+
|
87
|
+
class Handler(FileSystemEventHandler):
    """Watchdog event handler that forwards modified files to a callback."""

    def __init__(self, func: Callable[[Path], None]) -> None:
        # Callback invoked with the path of every modified regular file.
        self.func = func

    def on_modified(self, event: FileModifiedEvent) -> None:
        """Invoke the callback when the modified path is an existing file."""
        path = Path(event.src_path)
        if not path.is_file():
            # Directory events (and vanished paths) are ignored.
            return
        self.func(path)
|
95
|
+
|
96
|
+
|
97
|
+
@contextmanager
def chdir_artifact(
    run: Run | Series | str,
    artifact_path: str | None = None,
) -> Iterator[Path]:
    """Temporarily make a run's artifact directory the working directory.

    Args:
        run: An MLflow ``Run``, a row of ``mlflow.search_runs`` output, or a
            run ID.
        artifact_path: Optional subpath inside the artifact directory.

    Yields:
        The directory that became the working directory. The previous working
        directory is restored on exit, even if the body raises.
    """
    previous = Path.cwd()
    target = get_artifact_path(run, artifact_path)

    os.chdir(target)
    try:
        yield target
    finally:
        os.chdir(previous)
|
@@ -0,0 +1,17 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import mlflow
|
4
|
+
from hydra.core.hydra_config import HydraConfig
|
5
|
+
|
6
|
+
from hydraflow.config import iter_params
|
7
|
+
|
8
|
+
|
9
|
+
def set_experiment() -> None:
    """Select the MLflow experiment named after the current Hydra job.

    Calls ``mlflow.set_tracking_uri("")`` first, then
    ``mlflow.set_experiment`` with ``HydraConfig.get().job.name``.
    """
    job_name = HydraConfig.get().job.name
    mlflow.set_tracking_uri("")
    mlflow.set_experiment(job_name)
|
13
|
+
|
14
|
+
|
15
|
+
def log_params(config: object, *, synchronous: bool | None = None) -> None:
    """Log each flattened parameter of ``config`` to the active MLflow run.

    Args:
        config: Anything accepted by ``iter_params``.
        synchronous: Forwarded unchanged to ``mlflow.log_param``.
    """
    params = iter_params(config)
    for name, val in params:
        mlflow.log_param(name, val, synchronous=synchronous)
|
@@ -0,0 +1,172 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import TYPE_CHECKING, Any, overload
|
5
|
+
|
6
|
+
import mlflow
|
7
|
+
import numpy as np
|
8
|
+
from mlflow.entities.run import Run
|
9
|
+
from mlflow.tracking import artifact_utils
|
10
|
+
from omegaconf import DictConfig, OmegaConf
|
11
|
+
|
12
|
+
from hydraflow.config import iter_params
|
13
|
+
from hydraflow.util import uri_to_path
|
14
|
+
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
from typing import Any
|
17
|
+
|
18
|
+
from pandas import DataFrame, Series
|
19
|
+
|
20
|
+
|
21
|
+
@overload
def filter_by_config(runs: list[Run], config: object) -> list[Run]: ...


@overload
def filter_by_config(runs: DataFrame, config: object) -> DataFrame: ...


def filter_by_config(runs: list[Run] | DataFrame, config: object):
    """Keep only the runs whose logged parameters match ``config``.

    Args:
        runs: A list of ``Run`` objects or the DataFrame returned by
            ``mlflow.search_runs``.
        config: Config whose flattened parameters must match.

    Returns:
        A filtered object of the same kind as ``runs``.
    """
    if not isinstance(runs, list):
        return filter_by_config_dataframe(runs, config)

    return filter_by_config_list(runs, config)
|
34
|
+
|
35
|
+
|
36
|
+
def _is_equal(run: Run, key: str, value: Any) -> bool:
|
37
|
+
param = run.data.params.get(key, value)
|
38
|
+
if param is None:
|
39
|
+
return False
|
40
|
+
|
41
|
+
return type(value)(param) == value
|
42
|
+
|
43
|
+
|
44
|
+
def filter_by_config_list(runs: list[Run], config: object) -> list[Run]:
    """Return the runs whose parameters match every flattened entry of ``config``.

    Each ``(key, value)`` pair from ``iter_params`` narrows the candidates
    using ``_is_equal``. The relative order of ``runs`` is preserved. With an
    empty config the input list itself is returned.
    """
    remaining = runs
    for param_key, param_value in iter_params(config):
        remaining = [
            candidate
            for candidate in remaining
            if _is_equal(candidate, param_key, param_value)
        ]
    return remaining
|
49
|
+
|
50
|
+
|
51
|
+
def filter_by_config_dataframe(runs: DataFrame, config: object) -> DataFrame:
    """Return the rows of ``runs`` whose ``params.*`` columns match ``config``.

    For every flattened ``(key, value)`` pair from ``iter_params``:

    - If the ``params.<key>`` column exists, a row is kept only if the cell is
      non-missing and equals ``value`` after conversion to ``type(value)``.
    - If the column does not exist, that key places no constraint on the rows.

    Booleans are compared by their text, because MLflow stores parameters as
    strings and ``bool("False")`` is ``True``.

    Args:
        runs: DataFrame as returned by ``mlflow.search_runs``.
        config: Config whose flattened parameters must match.

    Returns:
        The matching subset of ``runs``.
    """
    index = np.ones(len(runs), dtype=bool)

    for key, value in iter_params(config):
        name = f"params.{key}"
        if name not in runs:
            continue

        series = runs[name]
        is_value = series.notna()
        if isinstance(value, bool):
            # astype(bool) would turn the string "False" into True.
            matched = series.astype(str).str.lower() == str(value).lower()
        else:
            # Missing cells are filled so astype succeeds; is_value excludes them.
            matched = series.fillna(value).astype(type(value)) == value
        index &= (is_value & matched).to_numpy(dtype=bool)

    return runs[index]
|
63
|
+
|
64
|
+
|
65
|
+
@overload
def get_by_config(runs: list[Run], config: object) -> Run: ...


@overload
def get_by_config(runs: DataFrame, config: object) -> Series: ...


def get_by_config(runs: list[Run] | DataFrame, config: object):
    """Return the single run matching ``config``.

    Returns:
        The matching ``Run`` when ``runs`` is a list, or the matching row as a
        ``Series`` when ``runs`` is a DataFrame.

    Raises:
        ValueError: If the number of matching runs is not exactly one.
    """
    matched = filter_by_config(runs, config)

    if len(matched) != 1:
        msg = f"filtered runs has not length of 1.: {len(matched)}"
        raise ValueError(msg)

    if isinstance(matched, list):
        return matched[0]

    return matched.iloc[0]
|
81
|
+
|
82
|
+
|
83
|
+
def drop_unique_params(runs: DataFrame) -> DataFrame:
    """Drop ``params.*`` columns that contain only one distinct value.

    Columns without the ``params.`` prefix are always kept, and the original
    column order is preserved.
    """
    mask = []
    for column in runs.columns:
        if column.startswith("params."):
            mask.append(len(runs[column].unique()) > 1)
        else:
            mask.append(True)
    return runs.iloc[:, mask]
|
89
|
+
|
90
|
+
|
91
|
+
def get_param_names(runs: DataFrame) -> list[str]:
    """Return parameter names (``params.`` prefix removed) in column order.

    Columns without the ``params.`` prefix are skipped. A bare ``params.``
    column, whose stripped name would be empty, is skipped too.
    """
    prefix = "params."
    stripped = (
        column.split(".", maxsplit=1)[-1]
        for column in runs.columns
        if column.startswith(prefix)
    )
    return [name for name in stripped if name]
|
100
|
+
|
101
|
+
|
102
|
+
def get_param_dict(runs: DataFrame) -> dict[str, list[str]]:
    """Map each parameter name to the list of its distinct values in ``runs``.

    Values appear in the order returned by ``Series.unique``.
    """
    return {
        name: list(runs[f"params.{name}"].unique())
        for name in get_param_names(runs)
    }
|
108
|
+
|
109
|
+
|
110
|
+
def get_run_id(run: Run | Series | str) -> str:
    """Return the run ID of ``run``.

    A string is taken to be a run ID and returned unchanged. A ``Run`` yields
    ``run.info.run_id``. Any other object (such as a row of the DataFrame
    returned by ``mlflow.search_runs``) must expose a ``run_id`` attribute.
    """
    if isinstance(run, str):
        return run
    if isinstance(run, Run):
        return run.info.run_id
    return run.run_id
|
116
|
+
|
117
|
+
|
118
|
+
def get_artifact_uri(run: Run | Series | str, artifact_path: str | None = None) -> str:
    """Return the artifact URI of ``run``, optionally joined with ``artifact_path``.

    A string ``run`` is treated as a run ID and resolved through
    ``artifact_utils.get_artifact_uri``. Other non-``Run`` inputs must expose
    an ``artifact_uri`` attribute (e.g. a search-results row).
    """
    if isinstance(run, str):
        base = artifact_utils.get_artifact_uri(run_id=run)
    elif isinstance(run, Run):
        base = run.info.artifact_uri
    else:
        base = run.artifact_uri

    return f"{base}/{artifact_path}" if artifact_path else base  # type: ignore
|
130
|
+
|
131
|
+
|
132
|
+
def get_artifact_dir(run: Run | Series | str) -> Path:
    """Return the local filesystem path of ``run``'s artifact root URI."""
    return uri_to_path(get_artifact_uri(run))
|
135
|
+
|
136
|
+
|
137
|
+
def get_artifact_path(
    run: Run | Series | str,
    artifact_path: str | None = None,
) -> Path:
    """Return the local path of ``artifact_path`` inside ``run``'s artifact root.

    When ``artifact_path`` is empty or ``None``, the artifact root is returned.
    """
    root = get_artifact_dir(run)
    if not artifact_path:
        return root
    return root / artifact_path
|
143
|
+
|
144
|
+
|
145
|
+
def load_config(run: Run | Series | str, output_subdir: str = ".hydra") -> DictConfig:
    """Load the ``config.yaml`` logged under ``output_subdir`` for ``run``.

    Args:
        run: A ``Run``, a search-results row, or a run ID.
        output_subdir: Artifact subdirectory that holds ``config.yaml``.

    Returns:
        The loaded config, or an empty ``DictConfig`` when downloading the
        artifact raises ``OSError`` (e.g. the artifact does not exist).
    """
    run_id = get_run_id(run)
    artifact = f"{output_subdir}/config.yaml"

    try:
        local_path = mlflow.artifacts.download_artifacts(
            run_id=run_id,
            artifact_path=artifact,
        )
    except OSError:
        return DictConfig({})

    return OmegaConf.load(local_path)  # type: ignore
|
157
|
+
|
158
|
+
|
159
|
+
def get_hydra_output_dir(run: Run | Series | str) -> Path:
    """Return the Hydra runtime output directory recorded for ``run``.

    Reads ``.hydra/hydra.yaml`` from the run's local artifact directory and
    returns its ``hydra.runtime.output_dir`` entry.

    Raises:
        FileNotFoundError: If the run has no ``.hydra/hydra.yaml`` artifact.
    """
    path = get_artifact_dir(run) / ".hydra/hydra.yaml"

    if not path.exists():
        # Include the path so the caller can see which run/location was checked.
        msg = f"Hydra config file not found: {path}"
        raise FileNotFoundError(msg)

    hc = OmegaConf.load(path)
    return Path(hc.hydra.runtime.output_dir)
|
167
|
+
|
168
|
+
|
169
|
+
def log_hydra_output_dir(run: Run | Series | str) -> None:
    """Upload ``run``'s Hydra output directory as artifacts of that same run.

    Args:
        run: A ``Run``, a search-results row (``Series``), or a run ID.

    Raises:
        FileNotFoundError: If the run has no ``.hydra/hydra.yaml`` artifact.
    """
    output_dir = get_hydra_output_dir(run)
    # get_run_id handles every accepted input; ``run.info.run_id`` fails for
    # a Series row, which the signature advertises.
    run_id = get_run_id(run)
    mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
|
@@ -0,0 +1,37 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from dataclasses import dataclass
|
5
|
+
|
6
|
+
import hydra
|
7
|
+
import mlflow
|
8
|
+
from hydra.core.config_store import ConfigStore
|
9
|
+
|
10
|
+
from hydraflow.context import log_run
|
11
|
+
|
12
|
+
# Module logger; ``app`` writes its START/END lines through it.
log = logging.getLogger(__name__)


@dataclass
class MySQLConfig:
    """Structured config for the ``log_run`` test application."""

    host: str = "localhost"
    port: int = 3306


# Register the schema under the name "config", the name that
# ``@hydra.main(config_name="config")`` below refers to.
cs = ConfigStore.instance()
cs.store(name="config", node=MySQLConfig)
|
23
|
+
|
24
|
+
|
25
|
+
@hydra.main(version_base=None, config_name="config")
def app(cfg: MySQLConfig):
    """Hydra entry point exercised by ``test_log_run``.

    It does four things:

    - Logs a START line with host and port, which the test parses back.
    - Records both directories of ``log_run``'s ``Info`` as text artifacts.
    - Writes ``a.txt`` directly into the artifact directory.
    - Logs an END line.
    """
    mlflow.set_experiment("log_run")
    with mlflow.start_run(), log_run(cfg) as info:
        log.info("START, %s, %s ", cfg.host, cfg.port)
        recorded = {
            "artifact_dir.txt": info.artifact_dir,
            "output_dir.txt": info.output_dir,
        }
        for artifact_name, directory in recorded.items():
            mlflow.log_text(directory.as_posix(), artifact_name)
        (info.artifact_dir / "a.txt").write_text("abc")
        log.info("END")


if __name__ == "__main__":
    app()
|
@@ -0,0 +1,63 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
|
3
|
+
import pytest
|
4
|
+
from omegaconf import OmegaConf
|
5
|
+
|
6
|
+
|
7
|
+
def test_iter_params():
    from hydraflow.config import iter_params

    conf = OmegaConf.create({"k": "v", "l": [1, {"a": "1", "b": "2", 3: "c"}]})
    # Compare the full sequence so unexpected extra pairs are also caught.
    assert list(iter_params(conf)) == [
        ("k", "v"),
        ("l.0", 1),
        ("l.1.a", "1"),
        ("l.1.b", "2"),
        ("l.1.3", "c"),
    ]
|
17
|
+
|
18
|
+
|
19
|
+
@dataclass
class Size:
    """Nested section with two integer fields."""

    x: int = 1
    y: int = 2


@dataclass
class Db:
    """Nested section mixing a string and an integer field."""

    name: str = "name"
    port: int = 100


@dataclass
class Store:
    """Nested section holding a list, which ``iter_params`` flattens by index."""

    # A factory keeps each instance's list independent.
    items: list[str] = field(default_factory=lambda: ["a", "b"])


@dataclass
class Config:
    """Top-level structured config used by the ``iter_params`` tests."""

    size: Size = field(default_factory=Size)
    db: Db = field(default_factory=Db)
    store: Store = field(default_factory=Store)
|
41
|
+
|
42
|
+
|
43
|
+
@pytest.fixture
def cfg():
    """A fresh default ``Config`` for each test."""
    return Config()


def test_config(cfg: Config):
    assert (cfg.size.x, cfg.db.name) == (1, "name")
    assert cfg.store.items == ["a", "b"]
|
52
|
+
|
53
|
+
|
54
|
+
def test_iter_params_from_config(cfg):
    from hydraflow.config import iter_params

    # Compare the full sequence so unexpected extra pairs are also caught.
    assert list(iter_params(cfg)) == [
        ("size.x", 1),
        ("size.y", 2),
        ("db.name", "name"),
        ("db.port", 100),
        ("store.items.0", "a"),
        ("store.items.1", "b"),
    ]
|
@@ -0,0 +1,64 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import subprocess
|
4
|
+
import sys
|
5
|
+
from pathlib import Path
|
6
|
+
|
7
|
+
import mlflow
|
8
|
+
import pytest
|
9
|
+
from mlflow.artifacts import download_artifacts
|
10
|
+
from mlflow.entities.run import Run
|
11
|
+
|
12
|
+
|
13
|
+
@pytest.fixture
def runs(monkeypatch, tmp_path):
    """Run ``tests/log_run.py`` as a 2x2 Hydra multirun and yield its MLflow runs."""
    # Resolve the script before changing into the temporary directory.
    script = Path("tests/log_run.py").absolute()
    monkeypatch.chdir(tmp_path)

    command = [sys.executable, script.as_posix(), "-m", "host=x,y", "port=1,2"]
    subprocess.check_call(command)

    mlflow.set_experiment("log_run")
    found = mlflow.search_runs(output_format="list")
    assert len(found) == 4
    assert isinstance(found, list)
    yield found


@pytest.fixture(params=range(4))
def run_id(runs, request):
    """The ID of each of the four runs in turn."""
    selected = runs[request.param]
    assert isinstance(selected, Run)
    return selected.info.run_id
|
34
|
+
|
35
|
+
|
36
|
+
def test_output(run_id: str):
    local = Path(download_artifacts(run_id=run_id, artifact_path="a.txt"))
    assert local.read_text() == "abc"


def read_log(run_id: str) -> str:
    """Return the ``log_run.log`` artifact text, checking START and END appear."""
    local = Path(download_artifacts(run_id=run_id, artifact_path="log_run.log"))
    text = local.read_text()
    assert "START" in text
    assert "END" in text
    return text


def test_load_config(run_id: str):
    from hydraflow.run import load_config

    first_line = read_log(run_id).splitlines()[0]
    host, port = first_line.split("START,")[-1].split(",")

    cfg = load_config(run_id)
    assert cfg.host == host.strip()
    assert cfg.port == int(port)


def test_load_config_err(run_id: str):
    from hydraflow.run import load_config

    # There is no "a/config.yaml" artifact, so an empty config comes back.
    assert not load_config(run_id, "a")
|
@@ -0,0 +1,188 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from pathlib import Path
|
4
|
+
|
5
|
+
import mlflow
|
6
|
+
import pytest
|
7
|
+
from mlflow.entities import Run
|
8
|
+
from pandas import DataFrame, Series
|
9
|
+
|
10
|
+
|
11
|
+
@pytest.fixture
def _runs(monkeypatch, tmp_path):
    """Create six runs in experiment ``test_run``; return them as list and DataFrame.

    Run ``x`` logs ``p=x``, ``q=0`` and an ``abc.txt`` artifact containing ``x``.
    """
    monkeypatch.chdir(tmp_path)

    mlflow.set_experiment("test_run")
    for x in range(6):
        with mlflow.start_run(run_name=f"{x}"):
            mlflow.log_param("p", x)
            mlflow.log_param("q", 0)
            mlflow.log_text(f"{x}", "abc.txt")

    as_list = mlflow.search_runs(output_format="list", order_by=["params.p"])
    assert isinstance(as_list, list)
    assert isinstance(as_list[0], Run)
    as_df = mlflow.search_runs(output_format="pandas", order_by=["params.p"])
    assert isinstance(as_df, DataFrame)
    return as_list, as_df


@pytest.fixture(params=["list", "pandas"])
def runs(_runs: tuple[list[Run], DataFrame], request: pytest.FixtureRequest):
    """Each representation of the runs in turn."""
    as_list, as_df = _runs
    return as_list if request.param == "list" else as_df


@pytest.fixture
def runs_list(_runs: tuple[list[Run], DataFrame], request: pytest.FixtureRequest):
    """The runs as a list of ``Run`` objects."""
    return _runs[0]


@pytest.fixture
def runs_df(_runs: tuple[list[Run], DataFrame], request: pytest.FixtureRequest):
    """The runs as a ``mlflow.search_runs`` DataFrame."""
    return _runs[1]
|
46
|
+
|
47
|
+
|
48
|
+
def test_filter_by_config_one(runs: list[Run] | DataFrame):
    from hydraflow.run import filter_by_config

    assert len(runs) == 6
    assert len(filter_by_config(runs, {"p": 1})) == 1


def test_filter_by_config_all(runs: list[Run] | DataFrame):
    from hydraflow.run import filter_by_config

    assert len(runs) == 6
    assert len(filter_by_config(runs, {"q": 0})) == 6


def test_get_by_config_list(runs_list: list[Run]):
    from hydraflow.run import get_by_config

    found = get_by_config(runs_list, {"p": 4})
    assert isinstance(found, Run)
    assert found.data.params["p"] == "4"


def test_get_by_config_df(runs_df: DataFrame):
    from hydraflow.run import get_by_config

    row = get_by_config(runs_df, {"p": 2})
    assert isinstance(row, Series)
    assert row["params.p"] == "2"


def test_get_by_config_error(runs: list[Run] | DataFrame):
    from hydraflow.run import get_by_config

    # All six runs share q=0, so the match is not unique.
    with pytest.raises(ValueError):
        get_by_config(runs, {"q": 0})
|
85
|
+
|
86
|
+
|
87
|
+
def test_drop_unique_params(runs_df):
    from hydraflow.run import drop_unique_params

    for column in ("params.p", "params.q"):
        assert column in runs_df

    dropped = drop_unique_params(runs_df)
    # p differs between runs, while q is always 0.
    assert "params.p" in dropped
    assert "params.q" not in dropped


def test_get_param_names(runs_df: DataFrame):
    from hydraflow.run import get_param_names

    names = get_param_names(runs_df)
    assert len(names) == 2
    assert {"p", "q"} <= set(names)


def test_get_param_dict(runs_df: DataFrame):
    from hydraflow.run import get_param_dict

    params = get_param_dict(runs_df)
    assert len(params["p"]) == 6
    assert len(params["q"]) == 1
|
112
|
+
|
113
|
+
|
114
|
+
@pytest.mark.parametrize("i", range(6))
def test_get_run_id(i: int, runs_list: list[Run], runs_df: DataFrame):
    from hydraflow.run import get_run_id

    expected = runs_list[i].info.run_id
    # Each accepted input form resolves to the same ID: the Run object,
    # the DataFrame row, and the ID string itself.
    assert get_run_id(runs_list[i]) == expected
    assert get_run_id(runs_df.iloc[i]) == expected
    assert get_run_id(expected) == expected
|
123
|
+
|
124
|
+
|
125
|
+
@pytest.mark.parametrize("i", range(6))
@pytest.mark.parametrize("path", [None, "a"])
def test_get_artifact_uri(i: int, path, runs_list: list[Run], runs_df: DataFrame):
    from hydraflow.run import get_artifact_uri, get_run_id

    expected = get_artifact_uri(runs_list[i], path)
    assert get_artifact_uri(get_run_id(runs_list[i]), path) == expected
    assert get_artifact_uri(runs_df.iloc[i], path) == expected


@pytest.mark.parametrize("i", range(6))
def test_chdir_artifact_list(i: int, runs_list: list[Run]):
    from hydraflow.context import chdir_artifact

    with chdir_artifact(runs_list[i]):
        assert Path("abc.txt").read_text() == f"{i}"

    # Back in the original directory, the artifact is no longer visible.
    assert not Path("abc.txt").exists()


def test_hydra_output_dir_error(runs_list: list[Run]):
    from hydraflow.run import get_hydra_output_dir

    # These runs were not created via Hydra, so .hydra/hydra.yaml is absent.
    with pytest.raises(FileNotFoundError):
        get_hydra_output_dir(runs_list[0])
|
151
|
+
|
152
|
+
|
153
|
+
@pytest.fixture
def df():
    """Small frame with one non-parameter column and three parameter columns."""
    return DataFrame(
        {
            "a": [0, 0, 0, 0],
            "params.x": [1, 1, 2, 2],
            "params.y": [1, 2, 1, 2],
            "params.z": [1, 1, 1, 1],
        },
    )


def test_unique_params(df):
    from hydraflow.run import drop_unique_params

    result = drop_unique_params(df)
    assert len(result.columns) == 3
    for kept in ("a", "params.x"):
        assert kept in result
    assert "params.z" not in result


def test_param_names(df):
    from hydraflow.run import get_param_names

    assert get_param_names(df) == ["x", "y", "z"]


def test_param_dict(df):
    from hydraflow.run import get_param_dict

    values = get_param_dict(df)
    assert values["x"] == [1, 2]
    assert values["y"] == [1, 2]
    assert values["z"] == [1]
|
@@ -0,0 +1,29 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import subprocess
|
4
|
+
import time
|
5
|
+
from pathlib import Path
|
6
|
+
|
7
|
+
import pytest
|
8
|
+
|
9
|
+
|
10
|
+
@pytest.mark.parametrize("dir", [".", Path])
def test_watch(dir, monkeypatch, tmp_path):
    import sys

    from hydraflow.context import watch

    file = Path("tests/watch.py").absolute()
    monkeypatch.chdir(tmp_path)

    lines = []

    def func(path: Path) -> None:
        k, t = path.read_text().split(" ")
        lines.append([int(k), float(t), time.time(), path.name])

    with watch(func, dir if isinstance(dir, str) else dir()):
        # Use the interpreter running pytest, not whatever "python" is on
        # PATH (which may be missing or a different interpreter).
        subprocess.check_call([sys.executable, file.as_posix()])

    for k in range(4):
        assert lines[k][0] == k
        assert lines[k][-1] == f"{k}.txt"
        assert 0 <= lines[k][2] - lines[k][1] < 0.05
|
@@ -0,0 +1,19 @@
|
|
1
|
+
import sys
|
2
|
+
import time
|
3
|
+
from pathlib import Path
|
4
|
+
|
5
|
+
|
6
|
+
def run():
    """Write ``0.txt`` to ``3.txt`` one second apart, each as ``"<k> <timestamp>"``.

    Before each write, ``k`` and the current time are printed: to stdout for
    the first two files and to stderr for the last two.
    """
    for k in range(4):
        stream = sys.stdout if k < 2 else sys.stderr
        print(k, time.time(), file=stream, flush=True)
        Path(f"{k}.txt").write_text(f"{k} {time.time()}")
        time.sleep(1)


if __name__ == "__main__":
    run()
|