hydraflow-0.1.0-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
hydraflow/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ from .context import Info, chdir_artifact, log_run, watch
2
+ from .mlflow import set_experiment
3
+ from .run import (
4
+ filter_by_config,
5
+ get_artifact_dir,
6
+ get_artifact_path,
7
+ get_artifact_uri,
8
+ get_by_config,
9
+ get_param_dict,
10
+ get_param_names,
11
+ get_run_id,
12
+ )
13
+
14
# Public names of the package (those re-exported by the imports above),
# listed in sorted order; this is what `from hydraflow import *` exposes.
__all__ = [
    "Info", "chdir_artifact", "filter_by_config", "get_artifact_dir",
    "get_artifact_path", "get_artifact_uri", "get_by_config", "get_param_dict",
    "get_param_names", "get_run_id", "log_run", "set_experiment", "watch",
]
hydraflow/config.py ADDED
@@ -0,0 +1,30 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from omegaconf import DictConfig, ListConfig, OmegaConf
6
+
7
+ if TYPE_CHECKING:
8
+ from collections.abc import Iterator
9
+ from typing import Any
10
+
11
+
12
def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
    """Yield flattened ``(dotted_key, value)`` pairs from a nested config.

    Args:
        config: A ``DictConfig``/``ListConfig``, or anything ``OmegaConf.create``
            accepts (it is converted first).
        prefix: Text prepended to every yielded key; recursion uses it to build
            dotted paths such as ``"model.layers.0"``.

    Yields:
        ``(key, value)`` tuples for every non-container leaf. Mapping keys and
        list indices are joined with ``"."``.
    """
    if not isinstance(config, (DictConfig, ListConfig)):
        config = OmegaConf.create(config)  # type: ignore

    # Mappings contribute their keys, sequences their positional indices.
    if isinstance(config, DictConfig):
        entries = config.items()
    elif isinstance(config, ListConfig):
        entries = enumerate(config)
    else:
        return

    for name, value in entries:
        path = f"{prefix}{name}"
        if isinstance(value, (DictConfig, ListConfig)):
            yield from iter_params(value, f"{path}.")
        else:
            yield path, value
hydraflow/context.py ADDED
@@ -0,0 +1,110 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import time
5
+ from contextlib import contextmanager
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import TYPE_CHECKING
9
+
10
+ import mlflow
11
+ from hydra.core.hydra_config import HydraConfig
12
+ from watchdog.events import FileModifiedEvent, FileSystemEventHandler
13
+ from watchdog.observers import Observer
14
+
15
+ from hydraflow.mlflow import log_params
16
+ from hydraflow.run import get_artifact_path
17
+ from hydraflow.util import uri_to_path
18
+
19
+ if TYPE_CHECKING:
20
+ from collections.abc import Callable, Iterator
21
+
22
+ from mlflow.entities.run import Run
23
+ from pandas import Series
24
+
25
+
26
@dataclass()
class Info:
    """Directories associated with a run, as yielded by ``log_run``."""

    # Hydra runtime output directory (``HydraConfig.get().runtime.output_dir``).
    output_dir: Path
    # Local path derived from the active MLflow run's artifact URI.
    artifact_dir: Path
30
+
31
+
32
@contextmanager
def log_run(
    config: object,
    *,
    synchronous: bool | None = None,
) -> Iterator[Info]:
    """Log a Hydra job's config and outputs to the active MLflow run.

    On entry, every flattened leaf of ``config`` is logged as an MLflow param,
    and the Hydra config sub-directory (``hydra.output_subdir``) is uploaded.
    On exit, the whole Hydra output directory is uploaded as run artifacts.
    This happens whether the body finished normally or raised.

    Args:
        config: Configuration whose leaves are logged as params.
        synchronous: Forwarded to ``log_params``.

    Yields:
        An ``Info`` holding the Hydra output directory and the local artifact
        directory of the active run.
    """
    log_params(config, synchronous=synchronous)

    hc = HydraConfig.get()
    output_dir = Path(hc.runtime.output_dir)
    uri = mlflow.get_artifact_uri()
    info = Info(output_dir, uri_to_path(uri))

    # Save the Hydra config sub-directory ('.hydra' by default) first.
    # When `hydra.output_subdir` is null there is no such directory. Joining
    # with "" would then upload the entire output_dir here and again in
    # `finally`, so the early upload is skipped.
    if hc.output_subdir:
        output_subdir = output_dir / hc.output_subdir
        mlflow.log_artifacts(output_subdir.as_posix(), hc.output_subdir)

    try:
        yield info

    finally:
        # Save output_dir, including the config sub-directory.
        mlflow.log_artifacts(output_dir.as_posix())
55
+
56
+
57
@contextmanager
def watch(
    func: Callable[[Path], None],
    dir: Path | str = "",
    timeout: int = 600,
) -> Iterator[None]:
    """Call ``func`` for files modified under ``dir`` while the block runs.

    A watchdog ``Observer`` watches ``dir`` recursively and dispatches events
    to a ``Handler`` wrapping ``func``. On exit, the observer's event queue
    gets up to ``timeout`` seconds to empty before the observer is stopped
    and joined.

    Args:
        func: Callback invoked with the ``Path`` of each modified file.
        dir: Directory to watch. When empty, the path of
            ``mlflow.get_artifact_uri()`` is used.
        timeout: Maximum wall-clock seconds to wait on exit for queued events.

    Yields:
        None.
    """
    if not dir:
        uri = mlflow.get_artifact_uri()
        dir = uri_to_path(uri)

    handler = Handler(func)
    observer = Observer()
    observer.schedule(handler, dir, recursive=True)
    observer.start()

    try:
        yield

    finally:
        # Bound the wait with a monotonic deadline rather than summing the
        # nominal sleep interval. Sleeps can overrun and each iteration does
        # extra work, so a running sum undercounts real elapsed time.
        deadline = time.monotonic() + timeout
        while not observer.event_queue.empty():
            if time.monotonic() >= deadline:
                break
            time.sleep(0.2)

        observer.stop()
        observer.join()
85
+
86
+
87
class Handler(FileSystemEventHandler):
    """Watchdog event handler that forwards modified files to a callback."""

    def __init__(self, func: Callable[[Path], None]) -> None:
        # Callback receiving the Path of each modified regular file.
        self.func = func

    def on_modified(self, event: FileModifiedEvent) -> None:
        """Call ``self.func`` with the event's path if it is a regular file."""
        path = Path(event.src_path)
        # Modification events can also refer to directories; ignore those.
        if not path.is_file():
            return
        self.func(path)
95
+
96
+
97
@contextmanager
def chdir_artifact(
    run: Run | Series | str,
    artifact_path: str | None = None,
) -> Iterator[Path]:
    """Temporarily make a run's artifact directory the working directory.

    Args:
        run: Run, row ``Series`` or run-ID string identifying the run.
        artifact_path: Optional sub-path inside the artifact directory.

    Yields:
        The directory that was made current. The previous working directory
        is restored on exit, even if the block raises.
    """
    previous = Path.cwd()
    target = get_artifact_path(run, artifact_path)

    os.chdir(target)
    try:
        yield target
    finally:
        os.chdir(previous)
hydraflow/mlflow.py ADDED
@@ -0,0 +1,17 @@
1
+ from __future__ import annotations
2
+
3
+ import mlflow
4
+ from hydra.core.hydra_config import HydraConfig
5
+
6
+ from hydraflow.config import iter_params
7
+
8
+
9
def set_experiment() -> None:
    """Select an MLflow experiment named after the current Hydra job.

    The tracking URI is first set to ``""``. Presumably this makes MLflow fall
    back to its default local store; confirm against the MLflow version in use.
    """
    job_name = HydraConfig.get().job.name
    mlflow.set_tracking_uri("")
    mlflow.set_experiment(job_name)
13
+
14
+
15
def log_params(config: object, *, synchronous: bool | None = None) -> None:
    """Log every leaf of ``config`` as an MLflow param keyed by its dotted path.

    Args:
        config: Any object accepted by ``iter_params``.
        synchronous: Forwarded unchanged to each ``mlflow.log_param`` call.
    """
    params = iter_params(config)
    for name, param in params:
        mlflow.log_param(name, param, synchronous=synchronous)
hydraflow/run.py ADDED
@@ -0,0 +1,172 @@
1
+ from __future__ import annotations
2
+
3
+ from pathlib import Path
4
+ from typing import TYPE_CHECKING, Any, overload
5
+
6
+ import mlflow
7
+ import numpy as np
8
+ from mlflow.entities.run import Run
9
+ from mlflow.tracking import artifact_utils
10
+ from omegaconf import DictConfig, OmegaConf
11
+
12
+ from hydraflow.config import iter_params
13
+ from hydraflow.util import uri_to_path
14
+
15
+ if TYPE_CHECKING:
16
+ from typing import Any
17
+
18
+ from pandas import DataFrame, Series
19
+
20
+
21
@overload
def filter_by_config(runs: list[Run], config: object) -> list[Run]: ...


@overload
def filter_by_config(runs: DataFrame, config: object) -> DataFrame: ...


def filter_by_config(runs: list[Run] | DataFrame, config: object):
    """Return the runs whose logged params match the leaves of ``config``.

    A ``list`` of ``Run`` objects is filtered by ``filter_by_config_list``.
    Anything else is treated as a DataFrame with ``params.*`` columns and
    filtered by ``filter_by_config_dataframe``. The result has the same
    container type as ``runs``.
    """
    if not isinstance(runs, list):
        return filter_by_config_dataframe(runs, config)
    return filter_by_config_list(runs, config)
34
+
35
+
36
def _param_matches(param: Any, value: Any) -> bool:
    """Return True if the logged param ``param`` equals the config ``value``.

    MLflow stores params as strings, so ``param`` is converted to
    ``type(value)`` before comparison, with two exceptions:

    * ``None`` matches the text ``"None"``, since ``NoneType`` cannot be
      called with an argument.
    * Booleans are compared as text, case-insensitively, because
      ``bool("False")`` is ``True``.

    A param that cannot be converted (e.g. ``int("1.0")``) does not match.
    """
    if value is None:
        return str(param) == "None"
    if isinstance(value, bool):
        return str(param).lower() == str(value).lower()
    try:
        return type(value)(param) == value
    except (TypeError, ValueError):
        return False


def _is_equal(run: Run, key: str, value: Any) -> bool:
    """Return True if ``run`` is compatible with ``key == value``.

    A run that never logged ``key`` counts as a match, except when ``value``
    is ``None``.
    """
    # A missing key falls back to `value` itself, so it compares equal to itself.
    param = run.data.params.get(key, value)
    if param is None:
        return False

    return _param_matches(param, value)


def filter_by_config_list(runs: list[Run], config: object) -> list[Run]:
    """Return the runs that match every ``(key, value)`` leaf of ``config``.

    Each leaf is checked with ``_is_equal``. A run that did not log a key is
    not excluded by it, unless the value is ``None``.
    """
    for key, value in iter_params(config):
        runs = [run for run in runs if _is_equal(run, key, value)]

    return runs


def filter_by_config_dataframe(runs: DataFrame, config: object) -> DataFrame:
    """Return the rows of ``runs`` whose ``params.*`` columns match ``config``.

    For each leaf ``(key, value)``, the column ``params.<key>`` is checked;
    keys without such a column are ignored. A row is kept only if every
    checked cell is present (not NaN) and matches ``value`` per
    ``_param_matches``.

    NOTE(review): unlike ``filter_by_config_list``, a run whose cell is
    missing is dropped rather than kept.
    """
    index = np.ones(len(runs), dtype=bool)

    for key, value in iter_params(config):
        name = f"params.{key}"
        if name not in runs:
            continue

        series = runs[name]
        present = series.notna().to_numpy(dtype=bool)
        # Element-wise, so bool/None handling matches the list path and one
        # unconvertible cell cannot make the whole column fail to cast.
        matches = np.array(
            [_param_matches(param, value) for param in series],
            dtype=bool,
        )
        index &= present & matches

    return runs[index]
63
+
64
+
65
@overload
def get_by_config(runs: list[Run], config: object) -> Run: ...


@overload
def get_by_config(runs: DataFrame, config: object) -> Series: ...


def get_by_config(runs: list[Run] | DataFrame, config: object):
    """Return the single run matching ``config``.

    Returns:
        The matching ``Run`` for list input, or the matching row ``Series``
        for DataFrame input.

    Raises:
        ValueError: If the number of matching runs is not exactly one.
    """
    filtered = filter_by_config(runs, config)
    count = len(filtered)

    if count != 1:
        msg = f"filtered runs has not length of 1.: {count}"
        raise ValueError(msg)

    if isinstance(filtered, list):
        return filtered[0]
    return filtered.iloc[0]
81
+
82
+
83
def drop_unique_params(runs: DataFrame) -> DataFrame:
    """Drop ``params.*`` columns that take a single value across all rows.

    Non-param columns are always kept. A param column is kept only when
    ``Series.unique`` finds more than one distinct value in it.
    """
    keep = []
    for column in runs.columns:
        is_param = column.startswith("params.")
        keep.append(not is_param or len(runs[column].unique()) > 1)
    return runs.iloc[:, keep]
89
+
90
+
91
def get_param_names(runs: DataFrame) -> list[str]:
    """Return param names from ``params.<name>`` columns, in column order.

    Everything after the leading ``"params."`` is the name, so dotted names
    are preserved. A bare ``"params."`` column (empty name) is skipped.
    """
    prefix = "params."
    names = (
        column.removeprefix(prefix)
        for column in runs.columns
        if column.startswith(prefix)
    )
    return [name for name in names if name]
100
+
101
+
102
def get_param_dict(runs: DataFrame) -> dict[str, list[str]]:
    """Map each param name to the list of its distinct values in ``runs``.

    Names come from ``get_param_names``. Values are taken from the matching
    ``params.<name>`` column via ``Series.unique``.
    """
    return {
        name: list(runs[f"params.{name}"].unique())
        for name in get_param_names(runs)
    }
108
+
109
+
110
def get_run_id(run: Run | Series | str) -> str:
    """Return the MLflow run ID of ``run``.

    A string is returned unchanged, and a ``Run`` yields ``run.info.run_id``.
    Any other object must expose a ``run_id`` attribute, presumably a row
    ``Series`` with a ``run_id`` column.
    """
    if isinstance(run, str):
        return run
    if isinstance(run, Run):
        return run.info.run_id
    return run.run_id
116
+
117
+
118
def get_artifact_uri(run: Run | Series | str, artifact_path: str | None = None) -> str:
    """Return the artifact URI of ``run``, optionally extended by a sub-path.

    Args:
        run: A ``Run`` (uses ``run.info.artifact_uri``), a run-ID string
            (resolved via ``artifact_utils.get_artifact_uri``), or an object
            with an ``artifact_uri`` attribute.
        artifact_path: If given, appended to the URI after a ``"/"``.
    """
    if isinstance(run, str):
        base = artifact_utils.get_artifact_uri(run_id=run)
    elif isinstance(run, Run):
        base = run.info.artifact_uri
    else:
        base = run.artifact_uri

    return f"{base}/{artifact_path}" if artifact_path else base  # type: ignore
130
+
131
+
132
def get_artifact_dir(run: Run | Series | str) -> Path:
    """Return the path component of ``run``'s artifact URI as a ``Path``.

    Only meaningful when the artifact store is on the local filesystem.
    """
    return uri_to_path(get_artifact_uri(run))
135
+
136
+
137
def get_artifact_path(
    run: Run | Series | str,
    artifact_path: str | None = None,
) -> Path:
    """Return ``run``'s artifact directory, or a path inside it.

    Args:
        run: Run, row ``Series`` or run-ID string.
        artifact_path: Optional relative path joined onto the artifact directory.
    """
    root = get_artifact_dir(run)
    if not artifact_path:
        return root
    return root / artifact_path
143
+
144
+
145
def load_config(run: Run | Series | str, output_subdir: str = ".hydra") -> DictConfig:
    """Load the Hydra ``config.yaml`` stored among ``run``'s artifacts.

    Args:
        run: Run, row ``Series`` or run-ID string.
        output_subdir: Artifact sub-directory containing ``config.yaml``.

    Returns:
        The loaded config. If the download raises ``OSError`` (presumably
        because the artifact is missing), an empty ``DictConfig`` is returned.
    """
    run_id = get_run_id(run)
    artifact = f"{output_subdir}/config.yaml"

    try:
        local_path = mlflow.artifacts.download_artifacts(
            run_id=run_id,
            artifact_path=artifact,
        )
    except OSError:
        return DictConfig({})

    return OmegaConf.load(local_path)  # type: ignore
157
+
158
+
159
def get_hydra_output_dir(run: Run | Series | str, output_subdir: str = ".hydra") -> Path:
    """Return the Hydra runtime output directory recorded for ``run``.

    Reads ``<artifact_dir>/<output_subdir>/hydra.yaml`` from the run's local
    artifact directory and returns its ``hydra.runtime.output_dir``.

    Args:
        run: Run, row ``Series`` or run-ID string.
        output_subdir: Artifact sub-directory holding ``hydra.yaml``. The
            default matches ``load_config``.

    Raises:
        FileNotFoundError: If ``hydra.yaml`` does not exist at that location.
    """
    path = get_artifact_dir(run) / output_subdir / "hydra.yaml"

    if not path.exists():
        msg = f"Hydra config file not found: {path}"
        raise FileNotFoundError(msg)

    hc = OmegaConf.load(path)
    return Path(hc.hydra.runtime.output_dir)
167
+
168
+
169
def log_hydra_output_dir(run: Run | Series | str) -> None:
    """Upload ``run``'s Hydra output directory as artifacts of that run.

    The directory comes from ``get_hydra_output_dir``. The run ID is resolved
    with ``get_run_id``, so a ``Run``, a row ``Series`` or a run-ID string are
    all accepted, as the annotation promises. A ``Series`` has no ``.info``
    attribute holding a ``run_id``.
    """
    output_dir = get_hydra_output_dir(run)
    run_id = get_run_id(run)
    mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
hydraflow/util.py ADDED
@@ -0,0 +1,11 @@
1
import platform
from pathlib import Path
from urllib.parse import unquote, urlparse
4
+
5
+
6
def uri_to_path(uri: str) -> Path:
    """Convert a ``file:`` URI (or a bare path) into a local ``Path``.

    The URI's path component is percent-decoded. ``Path.as_uri``-style URIs
    encode spaces and non-ASCII characters, so ``file:///a%20b`` becomes
    ``/a b``. On Windows, the leading slash of ``/C:/...`` is removed.
    """
    path = unquote(urlparse(uri).path)
    if platform.system() == "Windows" and path.startswith("/"):
        path = path[1:]

    return Path(path)
@@ -0,0 +1,29 @@
1
+ Metadata-Version: 2.3
2
+ Name: hydraflow
3
+ Version: 0.1.0
4
+ Summary: Hydra with MLflow
5
+ Project-URL: Source, https://github.com/daizutabi/hydraflow
6
+ Project-URL: Issues, https://github.com/daizutabi/hydraflow/issues
7
+ Author-email: daizutabi <daizutabi@gmail.com>
8
+ License-Expression: MIT
9
+ License-File: LICENSE
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: Programming Language :: Python
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Topic :: Documentation
16
+ Classifier: Topic :: Software Development :: Documentation
17
+ Requires-Python: >=3.10
18
+ Requires-Dist: hydra-core
19
+ Requires-Dist: mlflow
20
+ Requires-Dist: watchdog
21
+ Provides-Extra: dev
22
+ Requires-Dist: pytest-clarity; extra == 'dev'
23
+ Requires-Dist: pytest-cov; extra == 'dev'
24
+ Requires-Dist: pytest-randomly; extra == 'dev'
25
+ Requires-Dist: pytest-xdist; extra == 'dev'
26
+ Requires-Dist: setuptools; extra == 'dev'
27
+ Description-Content-Type: text/markdown
28
+
29
+ # hydraflow
@@ -0,0 +1,10 @@
1
+ hydraflow/__init__.py,sha256=9RaHPTloOOJYPUKKfPuK_wxKDr_J9A3rJ_gr-bLABD0,559
2
+ hydraflow/config.py,sha256=b3Plh_lmq94loZNw9QP2asd6thCLyTzzYSutH0cONXA,964
3
+ hydraflow/context.py,sha256=zBmbZWNLxUF2IDDPregPnR_sh3utmFwFJaneSsBsLDM,2558
4
+ hydraflow/mlflow.py,sha256=yDZ_oB1IZdCNNqHm_0LxdZ1Nld28IkW8Xl7NMhWLApE,453
5
+ hydraflow/run.py,sha256=XTAD_fd-ivvZ4tbjQLHrf6u5eAGRrrhqvExiZQcFnX8,4591
6
+ hydraflow/util.py,sha256=HTymDLqa2UzCw3kNjqHDaAZNdRMnrEAWhCJ7_ZD7ffA,264
7
+ hydraflow-0.1.0.dist-info/METADATA,sha256=WuryvAC_8MrC-UerPqbvcWxgBn9ABrnysQ0aRYimw3A,1021
8
+ hydraflow-0.1.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
+ hydraflow-0.1.0.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
10
+ hydraflow-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.25.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Daizu
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.