hydraflow 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hydraflow/__init__.py CHANGED
@@ -1,27 +1,35 @@
1
1
  from .context import Info, chdir_artifact, log_run, watch
2
2
  from .mlflow import set_experiment
3
- from .run import (
4
- filter_by_config,
3
+ from .runs import (
4
+ Run,
5
+ Runs,
6
+ drop_unique_params,
7
+ filter_runs,
5
8
  get_artifact_dir,
6
9
  get_artifact_path,
7
10
  get_artifact_uri,
8
- get_by_config,
9
11
  get_param_dict,
10
12
  get_param_names,
13
+ get_run,
11
14
  get_run_id,
15
+ load_config,
12
16
  )
13
17
 
14
18
  __all__ = [
15
19
  "Info",
20
+ "Run",
21
+ "Runs",
16
22
  "chdir_artifact",
17
- "filter_by_config",
23
+ "drop_unique_params",
24
+ "filter_runs",
18
25
  "get_artifact_dir",
19
26
  "get_artifact_path",
20
27
  "get_artifact_uri",
21
- "get_by_config",
22
28
  "get_param_dict",
23
29
  "get_param_names",
30
+ "get_run",
24
31
  "get_run_id",
32
+ "load_config",
25
33
  "log_run",
26
34
  "set_experiment",
27
35
  "watch",
hydraflow/context.py CHANGED
@@ -13,7 +13,7 @@ from watchdog.events import FileModifiedEvent, FileSystemEventHandler
13
13
  from watchdog.observers import Observer
14
14
 
15
15
  from hydraflow.mlflow import log_params
16
- from hydraflow.run import get_artifact_path
16
+ from hydraflow.runs import get_artifact_path
17
17
  from hydraflow.util import uri_to_path
18
18
 
19
19
  if TYPE_CHECKING:
@@ -40,14 +40,19 @@ def log_run(
40
40
  hc = HydraConfig.get()
41
41
  output_dir = Path(hc.runtime.output_dir)
42
42
  uri = mlflow.get_artifact_uri()
43
- location = Info(output_dir, uri_to_path(uri))
43
+ info = Info(output_dir, uri_to_path(uri))
44
44
 
45
45
  # Save '.hydra' config directory first.
46
46
  output_subdir = output_dir / (hc.output_subdir or "")
47
47
  mlflow.log_artifacts(output_subdir.as_posix(), hc.output_subdir)
48
48
 
49
+ def log_artifact(path: Path) -> None:
50
+ local_path = (output_dir / path).as_posix()
51
+ mlflow.log_artifact(local_path)
52
+
49
53
  try:
50
- yield location
54
+ with watch(log_artifact, output_dir):
55
+ yield info
51
56
 
52
57
  finally:
53
58
  # Save output_dir including '.hydra' config directory.
@@ -55,11 +60,7 @@ def log_run(
55
60
 
56
61
 
57
62
  @contextmanager
58
- def watch(
59
- func: Callable[[Path], None],
60
- dir: Path | str = "",
61
- timeout: int = 600,
62
- ) -> Iterator[None]:
63
+ def watch(func: Callable[[Path], None], dir: Path | str = "", timeout: int = 60) -> Iterator[None]:
63
64
  if not dir:
64
65
  uri = mlflow.get_artifact_uri()
65
66
  dir = uri_to_path(uri)
hydraflow/mlflow.py CHANGED
@@ -6,10 +6,13 @@ from hydra.core.hydra_config import HydraConfig
6
6
  from hydraflow.config import iter_params
7
7
 
8
8
 
9
- def set_experiment() -> None:
9
+ def set_experiment(prefix: str = "", suffix: str = "", uri: str | None = None) -> None:
10
+ if uri:
11
+ mlflow.set_tracking_uri(uri)
12
+
10
13
  hc = HydraConfig.get()
11
- mlflow.set_tracking_uri("")
12
- mlflow.set_experiment(hc.job.name)
14
+ name = f"{prefix}{hc.job.name}{suffix}"
15
+ mlflow.set_experiment(name)
13
16
 
14
17
 
15
18
  def log_params(config: object, *, synchronous: bool | None = None) -> None:
hydraflow/runs.py ADDED
@@ -0,0 +1,217 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from functools import cache
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ import mlflow
9
+ import numpy as np
10
+ from mlflow.entities.run import Run as Run_
11
+ from mlflow.tracking import artifact_utils
12
+ from omegaconf import DictConfig, OmegaConf
13
+ from pandas import DataFrame, Series
14
+
15
+ from hydraflow.config import iter_params
16
+ from hydraflow.util import uri_to_path
17
+
18
+ if TYPE_CHECKING:
19
+ from typing import Any
20
+
21
+
22
+ @dataclass
23
+ class Runs:
24
+ runs: list[Run_] | DataFrame
25
+
26
+ def __repr__(self) -> str:
27
+ return f"{self.__class__.__name__}({len(self)})"
28
+
29
+ def __len__(self) -> int:
30
+ return len(self.runs)
31
+
32
+ def filter(self, config: object) -> Runs:
33
+ return Runs(filter_runs(self.runs, config))
34
+
35
+ def get(self, config: object) -> Run:
36
+ return Run(get_run(self.runs, config))
37
+
38
+ def drop_unique_params(self) -> Runs:
39
+ if isinstance(self.runs, DataFrame):
40
+ return Runs(drop_unique_params(self.runs))
41
+
42
+ raise NotImplementedError
43
+
44
+ def get_param_names(self) -> list[str]:
45
+ if isinstance(self.runs, DataFrame):
46
+ return get_param_names(self.runs)
47
+
48
+ raise NotImplementedError
49
+
50
+ def get_param_dict(self) -> dict[str, list[str]]:
51
+ if isinstance(self.runs, DataFrame):
52
+ return get_param_dict(self.runs)
53
+
54
+ raise NotImplementedError
55
+
56
+
57
+ def filter_runs(runs: list[Run_] | DataFrame, config: object) -> list[Run_] | DataFrame:
58
+ if isinstance(runs, list):
59
+ return filter_runs_list(runs, config)
60
+
61
+ return filter_runs_dataframe(runs, config)
62
+
63
+
64
+ def _is_equal(run: Run_, key: str, value: Any) -> bool:
65
+ param = run.data.params.get(key, value)
66
+
67
+ if param is None:
68
+ return False
69
+
70
+ return type(value)(param) == value
71
+
72
+
73
+ def filter_runs_list(runs: list[Run_], config: object) -> list[Run_]:
74
+ for key, value in iter_params(config):
75
+ runs = [run for run in runs if _is_equal(run, key, value)]
76
+
77
+ return runs
78
+
79
+
80
+ def filter_runs_dataframe(runs: DataFrame, config: object) -> DataFrame:
81
+ index = np.ones(len(runs), dtype=bool)
82
+
83
+ for key, value in iter_params(config):
84
+ name = f"params.{key}"
85
+
86
+ if name in runs:
87
+ series = runs[name]
88
+ is_value = -series.isna()
89
+ param = series.fillna(value).astype(type(value))
90
+ index &= is_value & (param == value)
91
+
92
+ return runs[index]
93
+
94
+
95
+ def get_run(runs: list[Run_] | DataFrame, config: object) -> Run_ | Series:
96
+ runs = filter_runs(runs, config)
97
+
98
+ if len(runs) == 1:
99
+ return runs[0] if isinstance(runs, list) else runs.iloc[0]
100
+
101
+ msg = f"number of filtered runs is not 1: got {len(runs)}"
102
+ raise ValueError(msg)
103
+
104
+
105
+ def drop_unique_params(runs: DataFrame) -> DataFrame:
106
+ def select(column: str) -> bool:
107
+ return not column.startswith("params.") or len(runs[column].unique()) > 1
108
+
109
+ columns = [select(column) for column in runs.columns]
110
+ return runs.iloc[:, columns]
111
+
112
+
113
+ def get_param_names(runs: DataFrame) -> list[str]:
114
+ def get_name(column: str) -> str:
115
+ if column.startswith("params."):
116
+ return column.split(".", maxsplit=1)[-1]
117
+
118
+ return ""
119
+
120
+ columns = [get_name(column) for column in runs.columns]
121
+ return [column for column in columns if column]
122
+
123
+
124
+ def get_param_dict(runs: DataFrame) -> dict[str, list[str]]:
125
+ params = {}
126
+ for name in get_param_names(runs):
127
+ params[name] = list(runs[f"params.{name}"].unique())
128
+
129
+ return params
130
+
131
+
132
+ @dataclass
133
+ class Run:
134
+ run: Run_ | Series | str
135
+
136
+ def __repr__(self) -> str:
137
+ return f"{self.__class__.__name__}({self.run_id!r})"
138
+
139
+ @property
140
+ def run_id(self) -> str:
141
+ return get_run_id(self.run)
142
+
143
+ def artifact_uri(self, artifact_path: str | None = None) -> str:
144
+ return get_artifact_uri(self.run, artifact_path)
145
+
146
+ @property
147
+ def artifact_dir(self) -> Path:
148
+ return get_artifact_dir(self.run)
149
+
150
+ def artifact_path(self, artifact_path: str | None = None) -> Path:
151
+ return get_artifact_path(self.run, artifact_path)
152
+
153
+ @property
154
+ def config(self) -> DictConfig:
155
+ return load_config(self.run)
156
+
157
+ def log_hydra_output_dir(self) -> None:
158
+ log_hydra_output_dir(self.run)
159
+
160
+
161
+ def get_run_id(run: Run_ | Series | str) -> str:
162
+ if isinstance(run, str):
163
+ return run
164
+
165
+ if isinstance(run, Run_):
166
+ return run.info.run_id
167
+
168
+ return run.run_id
169
+
170
+
171
+ def get_artifact_uri(run: Run_ | Series | str, artifact_path: str | None = None) -> str:
172
+ run_id = get_run_id(run)
173
+ return artifact_utils.get_artifact_uri(run_id, artifact_path)
174
+
175
+
176
+ def get_artifact_dir(run: Run_ | Series | str) -> Path:
177
+ uri = get_artifact_uri(run)
178
+ return uri_to_path(uri)
179
+
180
+
181
+ def get_artifact_path(run: Run_ | Series | str, artifact_path: str | None = None) -> Path:
182
+ artifact_dir = get_artifact_dir(run)
183
+ return artifact_dir / artifact_path if artifact_path else artifact_dir
184
+
185
+
186
+ def load_config(run: Run_ | Series | str) -> DictConfig:
187
+ run_id = get_run_id(run)
188
+ return _load_config(run_id)
189
+
190
+
191
+ @cache
192
+ def _load_config(run_id: str) -> DictConfig:
193
+ try:
194
+ path = mlflow.artifacts.download_artifacts(
195
+ run_id=run_id,
196
+ artifact_path=".hydra/config.yaml",
197
+ )
198
+ except OSError:
199
+ return DictConfig({})
200
+
201
+ return OmegaConf.load(path) # type: ignore
202
+
203
+
204
+ def get_hydra_output_dir(run: Run_ | Series | str) -> Path:
205
+ path = get_artifact_dir(run) / ".hydra/hydra.yaml"
206
+
207
+ if path.exists():
208
+ hc = OmegaConf.load(path)
209
+ return Path(hc.hydra.runtime.output_dir)
210
+
211
+ raise FileNotFoundError
212
+
213
+
214
+ def log_hydra_output_dir(run: Run_ | Series | str) -> None:
215
+ output_dir = get_hydra_output_dir(run)
216
+ run_id = run if isinstance(run, str) else run.info.run_id
217
+ mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.1.1
3
+ Version: 0.1.4
4
4
  Summary: Hydra with MLflow
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -0,0 +1,10 @@
1
+ hydraflow/__init__.py,sha256=e1Q0Sskx39jaU2zkGNXjFWNC5xugEz_hDERTN_6Mzy8,666
2
+ hydraflow/config.py,sha256=b3Plh_lmq94loZNw9QP2asd6thCLyTzzYSutH0cONXA,964
3
+ hydraflow/context.py,sha256=3vejDbRYQBuBwlhpBpOv5aoyZ-yS8UUzpbCFK1V1uvw,2720
4
+ hydraflow/mlflow.py,sha256=unBP3Y7ujTM3E_Hq_eYvRVFZoGfTA7B0h4FkOZtPPqc,566
5
+ hydraflow/runs.py,sha256=127YykWzmiNUUuJSGPOCZasXmd6tcE15HU32j8x71ck,5864
6
+ hydraflow/util.py,sha256=_BdOMq5tKPm8HOehb2s2ZIBpJYyVpvO_yaAIxbSj51I,253
7
+ hydraflow-0.1.4.dist-info/METADATA,sha256=Xw-xcDKdzkHa7bKDZUI6MXpOKekcyFbMyBy1yANjNQs,1903
8
+ hydraflow-0.1.4.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
+ hydraflow-0.1.4.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
10
+ hydraflow-0.1.4.dist-info/RECORD,,
hydraflow/run.py DELETED
@@ -1,172 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from pathlib import Path
4
- from typing import TYPE_CHECKING, Any, overload
5
-
6
- import mlflow
7
- import numpy as np
8
- from mlflow.entities.run import Run
9
- from mlflow.tracking import artifact_utils
10
- from omegaconf import DictConfig, OmegaConf
11
-
12
- from hydraflow.config import iter_params
13
- from hydraflow.util import uri_to_path
14
-
15
- if TYPE_CHECKING:
16
- from typing import Any
17
-
18
- from pandas import DataFrame, Series
19
-
20
-
21
- @overload
22
- def filter_by_config(runs: list[Run], config: object) -> list[Run]: ...
23
-
24
-
25
- @overload
26
- def filter_by_config(runs: DataFrame, config: object) -> DataFrame: ...
27
-
28
-
29
- def filter_by_config(runs: list[Run] | DataFrame, config: object):
30
- if isinstance(runs, list):
31
- return filter_by_config_list(runs, config)
32
-
33
- return filter_by_config_dataframe(runs, config)
34
-
35
-
36
- def _is_equal(run: Run, key: str, value: Any) -> bool:
37
- param = run.data.params.get(key, value)
38
- if param is None:
39
- return False
40
-
41
- return type(value)(param) == value
42
-
43
-
44
- def filter_by_config_list(runs: list[Run], config: object) -> list[Run]:
45
- for key, value in iter_params(config):
46
- runs = [run for run in runs if _is_equal(run, key, value)]
47
-
48
- return runs
49
-
50
-
51
- def filter_by_config_dataframe(runs: DataFrame, config: object) -> DataFrame:
52
- index = np.ones(len(runs), dtype=bool)
53
-
54
- for key, value in iter_params(config):
55
- name = f"params.{key}"
56
- if name in runs:
57
- series = runs[name]
58
- is_value = -series.isna()
59
- param = series.fillna(value).astype(type(value))
60
- index &= is_value & (param == value)
61
-
62
- return runs[index]
63
-
64
-
65
- @overload
66
- def get_by_config(runs: list[Run], config: object) -> Run: ...
67
-
68
-
69
- @overload
70
- def get_by_config(runs: DataFrame, config: object) -> Series: ...
71
-
72
-
73
- def get_by_config(runs: list[Run] | DataFrame, config: object):
74
- runs = filter_by_config(runs, config)
75
-
76
- if len(runs) == 1:
77
- return runs[0] if isinstance(runs, list) else runs.iloc[0]
78
-
79
- msg = f"filtered runs has not length of 1.: {len(runs)}"
80
- raise ValueError(msg)
81
-
82
-
83
- def drop_unique_params(runs: DataFrame) -> DataFrame:
84
- def select(column: str) -> bool:
85
- return not column.startswith("params.") or len(runs[column].unique()) > 1
86
-
87
- columns = [select(column) for column in runs.columns]
88
- return runs.iloc[:, columns]
89
-
90
-
91
- def get_param_names(runs: DataFrame) -> list[str]:
92
- def get_name(column: str) -> str:
93
- if column.startswith("params."):
94
- return column.split(".", maxsplit=1)[-1]
95
-
96
- return ""
97
-
98
- columns = [get_name(column) for column in runs.columns]
99
- return [column for column in columns if column]
100
-
101
-
102
- def get_param_dict(runs: DataFrame) -> dict[str, list[str]]:
103
- params = {}
104
- for name in get_param_names(runs):
105
- params[name] = list(runs[f"params.{name}"].unique())
106
-
107
- return params
108
-
109
-
110
- def get_run_id(run: Run | Series | str) -> str:
111
- if isinstance(run, Run):
112
- return run.info.run_id
113
- if isinstance(run, str):
114
- return run
115
- return run.run_id
116
-
117
-
118
- def get_artifact_uri(run: Run | Series | str, artifact_path: str | None = None) -> str:
119
- if isinstance(run, Run):
120
- uri = run.info.artifact_uri
121
- elif isinstance(run, str):
122
- uri = artifact_utils.get_artifact_uri(run_id=run)
123
- else:
124
- uri = run.artifact_uri
125
-
126
- if artifact_path:
127
- uri = f"{uri}/{artifact_path}"
128
-
129
- return uri # type: ignore
130
-
131
-
132
- def get_artifact_dir(run: Run | Series | str) -> Path:
133
- uri = get_artifact_uri(run)
134
- return uri_to_path(uri)
135
-
136
-
137
- def get_artifact_path(
138
- run: Run | Series | str,
139
- artifact_path: str | None = None,
140
- ) -> Path:
141
- artifact_dir = get_artifact_dir(run)
142
- return artifact_dir / artifact_path if artifact_path else artifact_dir
143
-
144
-
145
- def load_config(run: Run | Series | str, output_subdir: str = ".hydra") -> DictConfig:
146
- run_id = get_run_id(run)
147
-
148
- try:
149
- path = mlflow.artifacts.download_artifacts(
150
- run_id=run_id,
151
- artifact_path=f"{output_subdir}/config.yaml",
152
- )
153
- except OSError:
154
- return DictConfig({})
155
-
156
- return OmegaConf.load(path) # type: ignore
157
-
158
-
159
- def get_hydra_output_dir(run: Run | Series | str) -> Path:
160
- path = get_artifact_dir(run) / ".hydra/hydra.yaml"
161
-
162
- if path.exists():
163
- hc = OmegaConf.load(path)
164
- return Path(hc.hydra.runtime.output_dir)
165
-
166
- raise FileNotFoundError
167
-
168
-
169
- def log_hydra_output_dir(run: Run | Series | str) -> None:
170
- output_dir = get_hydra_output_dir(run)
171
- run_id = run if isinstance(run, str) else run.info.run_id
172
- mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
@@ -1,10 +0,0 @@
1
- hydraflow/__init__.py,sha256=9RaHPTloOOJYPUKKfPuK_wxKDr_J9A3rJ_gr-bLABD0,559
2
- hydraflow/config.py,sha256=b3Plh_lmq94loZNw9QP2asd6thCLyTzzYSutH0cONXA,964
3
- hydraflow/context.py,sha256=zBmbZWNLxUF2IDDPregPnR_sh3utmFwFJaneSsBsLDM,2558
4
- hydraflow/mlflow.py,sha256=yDZ_oB1IZdCNNqHm_0LxdZ1Nld28IkW8Xl7NMhWLApE,453
5
- hydraflow/run.py,sha256=XTAD_fd-ivvZ4tbjQLHrf6u5eAGRrrhqvExiZQcFnX8,4591
6
- hydraflow/util.py,sha256=_BdOMq5tKPm8HOehb2s2ZIBpJYyVpvO_yaAIxbSj51I,253
7
- hydraflow-0.1.1.dist-info/METADATA,sha256=4QeC8CONrWskor7MylsUtq5lUMr1L33wQiFAn10urfI,1903
8
- hydraflow-0.1.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
- hydraflow-0.1.1.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
10
- hydraflow-0.1.1.dist-info/RECORD,,