hydraflow 0.5.1__tar.gz → 0.5.2__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {hydraflow-0.5.1 → hydraflow-0.5.2}/PKG-INFO +1 -1
- {hydraflow-0.5.1 → hydraflow-0.5.2}/pyproject.toml +2 -4
- {hydraflow-0.5.1 → hydraflow-0.5.2}/src/hydraflow/__init__.py +2 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/src/hydraflow/mlflow.py +6 -1
- {hydraflow-0.5.1 → hydraflow-0.5.2}/src/hydraflow/run_info.py +2 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/src/hydraflow/utils.py +24 -2
- hydraflow-0.5.2/tests/config/config.py +33 -0
- hydraflow-0.5.2/tests/config/test_config.py +29 -0
- hydraflow-0.5.2/tests/config/test_hydra.py +23 -0
- hydraflow-0.5.1/tests/test_config.py → hydraflow-0.5.2/tests/config/test_params.py +2 -66
- hydraflow-0.5.2/tests/conftest.py +81 -0
- hydraflow-0.5.2/tests/context/__init__.py +0 -0
- hydraflow-0.5.2/tests/context/context.py +44 -0
- hydraflow-0.5.2/tests/context/test_hydra.py +31 -0
- hydraflow-0.5.2/tests/param/__init__.py +0 -0
- hydraflow-0.5.2/tests/param/param.py +37 -0
- hydraflow-0.5.2/tests/param/test_hydra.py +33 -0
- hydraflow-0.5.2/tests/param/test_param.py +148 -0
- hydraflow-0.5.2/tests/run/__init__.py +0 -0
- hydraflow-0.5.2/tests/run/filter.py +33 -0
- hydraflow-0.5.2/tests/run/run.py +37 -0
- hydraflow-0.5.1/tests/test_run_collection.py → hydraflow-0.5.2/tests/run/test_collection.py +38 -13
- hydraflow-0.5.1/tests/test_run_data.py → hydraflow-0.5.2/tests/run/test_data.py +7 -4
- hydraflow-0.5.2/tests/run/test_filter.py +17 -0
- hydraflow-0.5.2/tests/run/test_hydra.py +52 -0
- hydraflow-0.5.1/tests/test_run_info.py → hydraflow-0.5.2/tests/run/test_info.py +7 -4
- hydraflow-0.5.2/tests/test_mlflow.py +91 -0
- hydraflow-0.5.2/tests/utils/__init__.py +0 -0
- hydraflow-0.5.2/tests/utils/test_hydra.py +50 -0
- hydraflow-0.5.2/tests/utils/test_run.py +47 -0
- hydraflow-0.5.2/tests/utils/utils.py +36 -0
- hydraflow-0.5.1/tests/apps/app.py +0 -70
- hydraflow-0.5.1/tests/conftest.py +0 -16
- hydraflow-0.5.1/tests/test_app.py +0 -213
- hydraflow-0.5.1/tests/test_context.py +0 -68
- hydraflow-0.5.1/tests/test_log_run.py +0 -89
- hydraflow-0.5.1/tests/test_mlflow.py +0 -35
- hydraflow-0.5.1/tests/test_param.py +0 -119
- hydraflow-0.5.1/tests/test_utils.py +0 -32
- {hydraflow-0.5.1 → hydraflow-0.5.2}/.devcontainer/devcontainer.json +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/.devcontainer/postCreate.sh +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/.devcontainer/starship.toml +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/.gitattributes +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/.gitignore +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/LICENSE +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/README.md +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/apps/quickstart.py +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/mkdocs.yml +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/src/hydraflow/config.py +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/src/hydraflow/context.py +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/src/hydraflow/param.py +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/src/hydraflow/py.typed +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/src/hydraflow/run_collection.py +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/src/hydraflow/run_data.py +0 -0
- {hydraflow-0.5.1 → hydraflow-0.5.2}/tests/__init__.py +0 -0
- {hydraflow-0.5.1/tests/apps → hydraflow-0.5.2/tests/config}/__init__.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: hydraflow
|
3
|
-
Version: 0.5.1
|
3
|
+
Version: 0.5.2
|
4
4
|
Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
|
5
5
|
Project-URL: Documentation, https://github.com/daizutabi/hydraflow
|
6
6
|
Project-URL: Source, https://github.com/daizutabi/hydraflow
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "hydraflow"
|
7
|
-
version = "0.5.1"
|
7
|
+
version = "0.5.2"
|
8
8
|
description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
|
9
9
|
readme = "README.md"
|
10
10
|
license = { file = "LICENSE" }
|
@@ -32,10 +32,7 @@ dev-dependencies = [
|
|
32
32
|
"mkapi",
|
33
33
|
"mkdocs-material",
|
34
34
|
"mkdocs>=1.6",
|
35
|
-
"pytest-clarity",
|
36
35
|
"pytest-cov",
|
37
|
-
"pytest-randomly",
|
38
|
-
"pytest-xdist",
|
39
36
|
]
|
40
37
|
|
41
38
|
[tool.hatch.build.targets.sdist]
|
@@ -96,6 +93,7 @@ exclude = ["tests/scripts/*.py"]
|
|
96
93
|
"PT",
|
97
94
|
"S",
|
98
95
|
"SIM117",
|
96
|
+
"TID",
|
99
97
|
"SLF",
|
100
98
|
]
|
101
99
|
"apps/*.py" = ["INP", "D", "G", "T"]
|
@@ -6,6 +6,7 @@ from .mlflow import list_runs, search_runs, set_experiment
|
|
6
6
|
from .run_collection import RunCollection
|
7
7
|
from .utils import (
|
8
8
|
get_artifact_dir,
|
9
|
+
get_artifact_path,
|
9
10
|
get_hydra_output_dir,
|
10
11
|
get_overrides,
|
11
12
|
load_config,
|
@@ -18,6 +19,7 @@ __all__ = [
|
|
18
19
|
"chdir_artifact",
|
19
20
|
"chdir_hydra_output",
|
20
21
|
"get_artifact_dir",
|
22
|
+
"get_artifact_path",
|
21
23
|
"get_hydra_output_dir",
|
22
24
|
"get_overrides",
|
23
25
|
"list_runs",
|
@@ -37,6 +37,7 @@ def set_experiment(
|
|
37
37
|
prefix: str = "",
|
38
38
|
suffix: str = "",
|
39
39
|
uri: str | Path | None = None,
|
40
|
+
name: str | None = None,
|
40
41
|
) -> Experiment:
|
41
42
|
"""Set the experiment name and tracking URI optionally.
|
42
43
|
|
@@ -48,6 +49,7 @@ def set_experiment(
|
|
48
49
|
prefix (str): The prefix to prepend to the experiment name.
|
49
50
|
suffix (str): The suffix to append to the experiment name.
|
50
51
|
uri (str | Path | None): The tracking URI to use. Defaults to None.
|
52
|
+
name (str | None): The name of the experiment. Defaults to None.
|
51
53
|
|
52
54
|
Returns:
|
53
55
|
Experiment: An instance of `mlflow.entities.Experiment` representing
|
@@ -57,6 +59,9 @@ def set_experiment(
|
|
57
59
|
if uri is not None:
|
58
60
|
mlflow.set_tracking_uri(uri)
|
59
61
|
|
62
|
+
if name is not None:
|
63
|
+
return mlflow.set_experiment(name)
|
64
|
+
|
60
65
|
hc = HydraConfig.get()
|
61
66
|
name = f"{prefix}{hc.job.name}{suffix}"
|
62
67
|
return mlflow.set_experiment(name)
|
@@ -214,7 +219,7 @@ def _list_runs(
|
|
214
219
|
elif Path(loc).is_dir():
|
215
220
|
path = Path(loc)
|
216
221
|
else:
|
217
|
-
continue
|
222
|
+
continue # no cov
|
218
223
|
|
219
224
|
run_ids.extend(file.stem for file in path.iterdir() if file.is_dir())
|
220
225
|
|
@@ -30,10 +30,32 @@ def get_artifact_dir(run: Run | None = None) -> Path:
|
|
30
30
|
"""
|
31
31
|
uri = mlflow.get_artifact_uri() if run is None else run.info.artifact_uri
|
32
32
|
|
33
|
-
if not
|
33
|
+
if not isinstance(uri, str):
|
34
34
|
raise NotImplementedError
|
35
35
|
|
36
|
-
|
36
|
+
if uri.startswith("file://"):
|
37
|
+
return Path(mlflow.artifacts.download_artifacts(uri))
|
38
|
+
|
39
|
+
if Path(uri).is_dir():
|
40
|
+
return Path(uri)
|
41
|
+
|
42
|
+
raise NotImplementedError
|
43
|
+
|
44
|
+
|
45
|
+
def get_artifact_path(run: Run | None, path: str) -> Path:
|
46
|
+
"""Retrieve the artifact path for the given run and path.
|
47
|
+
|
48
|
+
This function uses MLflow to get the artifact path for the given run and path.
|
49
|
+
|
50
|
+
Args:
|
51
|
+
run (Run | None): The run object. Defaults to None.
|
52
|
+
path (str): The path to the artifact.
|
53
|
+
|
54
|
+
Returns:
|
55
|
+
The local path to the artifact.
|
56
|
+
|
57
|
+
"""
|
58
|
+
return get_artifact_dir(run) / path
|
37
59
|
|
38
60
|
|
39
61
|
def get_hydra_output_dir(run: Run | None = None) -> Path:
|
@@ -0,0 +1,33 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
import hydra
|
6
|
+
import mlflow
|
7
|
+
from hydra.core.config_store import ConfigStore
|
8
|
+
|
9
|
+
import hydraflow
|
10
|
+
|
11
|
+
|
12
|
+
@dataclass
|
13
|
+
class Config:
|
14
|
+
name: str = "a"
|
15
|
+
age: int = 1
|
16
|
+
height: float = 1.7
|
17
|
+
|
18
|
+
|
19
|
+
cs = ConfigStore.instance()
|
20
|
+
cs.store(name="config", node=Config)
|
21
|
+
|
22
|
+
|
23
|
+
@hydra.main(version_base=None, config_name="config")
|
24
|
+
def app(cfg: Config):
|
25
|
+
hydraflow.set_experiment()
|
26
|
+
|
27
|
+
with hydraflow.start_run(cfg):
|
28
|
+
overrides = hydraflow.select_overrides(Config(name="x", height=2))
|
29
|
+
mlflow.log_text(str(overrides), "overrides.txt")
|
30
|
+
|
31
|
+
|
32
|
+
if __name__ == "__main__":
|
33
|
+
app()
|
@@ -0,0 +1,29 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
|
3
|
+
|
4
|
+
@dataclass
|
5
|
+
class C:
|
6
|
+
z: int = 3
|
7
|
+
|
8
|
+
|
9
|
+
@dataclass
|
10
|
+
class B:
|
11
|
+
y: int = 2
|
12
|
+
c: C = field(default_factory=C)
|
13
|
+
|
14
|
+
|
15
|
+
@dataclass
|
16
|
+
class A:
|
17
|
+
x: int = 1
|
18
|
+
b: B = field(default_factory=B)
|
19
|
+
|
20
|
+
|
21
|
+
def test_select_config():
|
22
|
+
from hydraflow.config import select_config
|
23
|
+
|
24
|
+
a = A()
|
25
|
+
assert select_config(a, ["x"]) == {"x": 1}
|
26
|
+
assert select_config(a, ["b.y"]) == {"b.y": 2}
|
27
|
+
assert select_config(a, ["b.c.z"]) == {"b.c.z": 3}
|
28
|
+
assert select_config(a, ["b.c.z", "x"]) == {"b.c.z": 3, "x": 1}
|
29
|
+
assert select_config(a, ["b.c.z", "b.y"]) == {"b.c.z": 3, "b.y": 2}
|
@@ -0,0 +1,23 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
|
3
|
+
import pytest
|
4
|
+
from mlflow.artifacts import download_artifacts
|
5
|
+
from mlflow.entities import Run
|
6
|
+
|
7
|
+
from hydraflow.run_collection import RunCollection
|
8
|
+
|
9
|
+
|
10
|
+
@pytest.fixture(scope="module")
|
11
|
+
def rc(collect):
|
12
|
+
args = ["-m", "name=a,b", "height=3"]
|
13
|
+
return collect("config/config.py", args)
|
14
|
+
|
15
|
+
|
16
|
+
@pytest.fixture(scope="module")
|
17
|
+
def run(rc: RunCollection):
|
18
|
+
return rc.first()
|
19
|
+
|
20
|
+
|
21
|
+
def test_select_overrides(run: Run):
|
22
|
+
path = download_artifacts(f"{run.info.artifact_uri}/overrides.txt")
|
23
|
+
assert Path(path).read_text() == "{'name': 'x', 'height': 2.0}"
|
@@ -4,10 +4,10 @@ from dataclasses import dataclass, field
|
|
4
4
|
import pytest
|
5
5
|
from omegaconf import ListConfig, OmegaConf
|
6
6
|
|
7
|
+
from hydraflow.config import _is_param, collect_params, iter_params
|
7
8
|
|
8
|
-
def test_is_param_with_simple_values():
|
9
|
-
from hydraflow.config import _is_param
|
10
9
|
|
10
|
+
def test_is_param_with_simple_values():
|
11
11
|
assert _is_param(1) is True
|
12
12
|
assert _is_param("string") is True
|
13
13
|
assert _is_param(3.14) is True
|
@@ -15,69 +15,49 @@ def test_is_param_with_simple_values():
|
|
15
15
|
|
16
16
|
|
17
17
|
def test_is_param_with_dictconfig_containing_simple_values():
|
18
|
-
from hydraflow.config import _is_param
|
19
|
-
|
20
18
|
dict_conf = OmegaConf.create({"a": 1, "b": "string", "c": 3.14, "d": True})
|
21
19
|
assert _is_param(dict_conf) is False
|
22
20
|
|
23
21
|
|
24
22
|
def test_is_param_with_listconfig_containing_simple_values():
|
25
|
-
from hydraflow.config import _is_param
|
26
|
-
|
27
23
|
list_conf = OmegaConf.create([1, "string", 3.14, True])
|
28
24
|
assert _is_param(list_conf) is True
|
29
25
|
|
30
26
|
|
31
27
|
def test_is_param_with_listconfig_containing_nested_dictconfig():
|
32
|
-
from hydraflow.config import _is_param
|
33
|
-
|
34
28
|
nested_list_conf = OmegaConf.create([1, {"a": 1}, 3.14])
|
35
29
|
assert _is_param(nested_list_conf) is False
|
36
30
|
|
37
31
|
|
38
32
|
def test_is_param_with_listconfig_containing_nested_listconfig():
|
39
|
-
from hydraflow.config import _is_param
|
40
|
-
|
41
33
|
nested_list_conf_2 = OmegaConf.create([1, [2, 3], 3.14])
|
42
34
|
assert _is_param(nested_list_conf_2) is False
|
43
35
|
|
44
36
|
|
45
37
|
def test_is_param_with_empty_dictconfig():
|
46
|
-
from hydraflow.config import _is_param
|
47
|
-
|
48
38
|
empty_dict_conf = OmegaConf.create({})
|
49
39
|
assert _is_param(empty_dict_conf) is False
|
50
40
|
|
51
41
|
|
52
42
|
def test_is_param_with_empty_listconfig():
|
53
|
-
from hydraflow.config import _is_param
|
54
|
-
|
55
43
|
empty_list_conf = OmegaConf.create([])
|
56
44
|
assert _is_param(empty_list_conf) is True
|
57
45
|
|
58
46
|
|
59
47
|
def test_is_param_with_none():
|
60
|
-
from hydraflow.config import _is_param
|
61
|
-
|
62
48
|
assert _is_param(None) is True
|
63
49
|
|
64
50
|
|
65
51
|
def test_is_param_with_complex_nested_structure():
|
66
|
-
from hydraflow.config import _is_param
|
67
|
-
|
68
52
|
complex_conf = OmegaConf.create({"a": [1, {"b": 2}], "c": {"d": 3}})
|
69
53
|
assert _is_param(complex_conf) is False
|
70
54
|
|
71
55
|
|
72
56
|
def test_iter_params_with_none():
|
73
|
-
from hydraflow.config import iter_params
|
74
|
-
|
75
57
|
assert not list(iter_params(None))
|
76
58
|
|
77
59
|
|
78
60
|
def test_iter_params():
|
79
|
-
from hydraflow.config import iter_params
|
80
|
-
|
81
61
|
conf = OmegaConf.create({"k": "v", "l": [1, {"a": "1", "b": "2", 3: "c"}]})
|
82
62
|
it = iter_params(conf)
|
83
63
|
assert next(it) == ("k", "v")
|
@@ -88,8 +68,6 @@ def test_iter_params():
|
|
88
68
|
|
89
69
|
|
90
70
|
def test_collect_params():
|
91
|
-
from hydraflow.config import collect_params
|
92
|
-
|
93
71
|
conf = OmegaConf.create({"k": "v", "l": [1, {"a": "1", "b": "2", 3: "c"}]})
|
94
72
|
params = collect_params(conf)
|
95
73
|
assert params == {"k": "v", "l.0": 1, "l.1.a": "1", "l.1.b": "2", "l.1.3": "c"}
|
@@ -131,8 +109,6 @@ def test_config(cfg: Config):
|
|
131
109
|
|
132
110
|
|
133
111
|
def test_iter_params_from_config(cfg):
|
134
|
-
from hydraflow.config import iter_params
|
135
|
-
|
136
112
|
it = iter_params(cfg)
|
137
113
|
assert next(it) == ("size.x", 1)
|
138
114
|
assert next(it) == ("size.y", 2)
|
@@ -142,8 +118,6 @@ def test_iter_params_from_config(cfg):
|
|
142
118
|
|
143
119
|
|
144
120
|
def test_iter_params_with_empty_config():
|
145
|
-
from hydraflow.config import iter_params
|
146
|
-
|
147
121
|
empty_cfg = Config(
|
148
122
|
size=Size(x=0, y=0),
|
149
123
|
db=Db(name="", port=0),
|
@@ -158,8 +132,6 @@ def test_iter_params_with_empty_config():
|
|
158
132
|
|
159
133
|
|
160
134
|
def test_iter_params_with_nested_config():
|
161
|
-
from hydraflow.config import iter_params
|
162
|
-
|
163
135
|
@dataclass
|
164
136
|
class Nested:
|
165
137
|
level1: Config = field(default_factory=Config)
|
@@ -174,8 +146,6 @@ def test_iter_params_with_nested_config():
|
|
174
146
|
|
175
147
|
|
176
148
|
def test_iter_params_with_mixed_types_in_list():
|
177
|
-
from hydraflow.config import iter_params
|
178
|
-
|
179
149
|
@dataclass
|
180
150
|
class MixedStore:
|
181
151
|
items: list = field(default_factory=lambda: ["a", 1, {"key": "value"}])
|
@@ -209,48 +179,14 @@ def test_list_config_str(s):
|
|
209
179
|
|
210
180
|
@pytest.mark.parametrize("x", [{"a": 1}, {"a": [1, 2, 3]}])
|
211
181
|
def test_collect_params_dict(x):
|
212
|
-
from hydraflow.config import collect_params
|
213
|
-
|
214
182
|
assert collect_params(x) == x
|
215
183
|
|
216
184
|
|
217
185
|
def test_collect_params_dict_dot():
|
218
|
-
from hydraflow.config import collect_params
|
219
|
-
|
220
186
|
assert collect_params({"a": {"b": 1}}) == {"a.b": 1}
|
221
187
|
assert collect_params({"a.b": 1}) == {"a.b": 1}
|
222
188
|
|
223
189
|
|
224
190
|
def test_collect_params_list_dot():
|
225
|
-
from hydraflow.config import collect_params
|
226
|
-
|
227
191
|
assert collect_params(["a=1"]) == {"a": "1"}
|
228
192
|
assert collect_params(["a.b=2", "c"]) == {"a.b": "2"}
|
229
|
-
|
230
|
-
|
231
|
-
@dataclass
|
232
|
-
class C:
|
233
|
-
z: int = 3
|
234
|
-
|
235
|
-
|
236
|
-
@dataclass
|
237
|
-
class B:
|
238
|
-
y: int = 2
|
239
|
-
c: C = field(default_factory=C)
|
240
|
-
|
241
|
-
|
242
|
-
@dataclass
|
243
|
-
class A:
|
244
|
-
x: int = 1
|
245
|
-
b: B = field(default_factory=B)
|
246
|
-
|
247
|
-
|
248
|
-
def test_select_config():
|
249
|
-
from hydraflow.config import select_config
|
250
|
-
|
251
|
-
a = A()
|
252
|
-
assert select_config(a, ["x"]) == {"x": 1}
|
253
|
-
assert select_config(a, ["b.y"]) == {"b.y": 2}
|
254
|
-
assert select_config(a, ["b.c.z"]) == {"b.c.z": 3}
|
255
|
-
assert select_config(a, ["b.c.z", "x"]) == {"b.c.z": 3, "x": 1}
|
256
|
-
assert select_config(a, ["b.c.z", "b.y"]) == {"b.c.z": 3, "b.y": 2}
|
@@ -0,0 +1,81 @@
|
|
1
|
+
import importlib
|
2
|
+
import os
|
3
|
+
import subprocess
|
4
|
+
import sys
|
5
|
+
import uuid
|
6
|
+
from pathlib import Path
|
7
|
+
|
8
|
+
import pytest
|
9
|
+
from hydra import compose, initialize
|
10
|
+
from hydra.core.config_store import ConfigStore
|
11
|
+
|
12
|
+
|
13
|
+
@pytest.fixture(scope="module")
|
14
|
+
def experiment_name(tmp_path_factory: pytest.TempPathFactory):
|
15
|
+
cwd = Path.cwd()
|
16
|
+
name = str(uuid.uuid4())
|
17
|
+
|
18
|
+
os.chdir(tmp_path_factory.mktemp(name))
|
19
|
+
|
20
|
+
yield name
|
21
|
+
|
22
|
+
os.chdir(cwd)
|
23
|
+
|
24
|
+
|
25
|
+
@pytest.fixture(scope="module")
|
26
|
+
def run_script(experiment_name: str):
|
27
|
+
parent = Path(__file__).parent
|
28
|
+
|
29
|
+
def run_script(filename: str, args: list[str]):
|
30
|
+
file = parent / filename
|
31
|
+
job_name = f"hydra.job.name={experiment_name}"
|
32
|
+
|
33
|
+
args = [sys.executable, file.as_posix(), *args, job_name]
|
34
|
+
subprocess.run(args, check=False)
|
35
|
+
|
36
|
+
return experiment_name
|
37
|
+
|
38
|
+
return run_script
|
39
|
+
|
40
|
+
|
41
|
+
@pytest.fixture(scope="module")
|
42
|
+
def collect(run_script):
|
43
|
+
from hydraflow.mlflow import search_runs
|
44
|
+
|
45
|
+
def collect(filename: str, args: list[str]):
|
46
|
+
experiment_name = run_script(filename, args)
|
47
|
+
return search_runs(experiment_names=[experiment_name])
|
48
|
+
|
49
|
+
return collect
|
50
|
+
|
51
|
+
|
52
|
+
@pytest.fixture(scope="module")
|
53
|
+
def get_config_class():
|
54
|
+
parent = Path(__file__).parent
|
55
|
+
|
56
|
+
def get_config_class(filename: str):
|
57
|
+
file = parent / filename
|
58
|
+
|
59
|
+
sys.path.insert(0, file.parent.as_posix())
|
60
|
+
module = importlib.import_module(file.stem)
|
61
|
+
del sys.path[0]
|
62
|
+
|
63
|
+
return module.Config
|
64
|
+
|
65
|
+
return get_config_class
|
66
|
+
|
67
|
+
|
68
|
+
@pytest.fixture
|
69
|
+
def get_config(get_config_class):
|
70
|
+
cs = ConfigStore.instance()
|
71
|
+
|
72
|
+
def get_config(filename: str, overrides: list[str] | None = None):
|
73
|
+
cls = get_config_class(filename)
|
74
|
+
|
75
|
+
name = str(uuid.uuid4())
|
76
|
+
cs.store(name=name, node=cls)
|
77
|
+
|
78
|
+
with initialize(version_base=None):
|
79
|
+
return compose(config_name=name, overrides=overrides)
|
80
|
+
|
81
|
+
return get_config
|
File without changes
|
@@ -0,0 +1,44 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import sys
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from pathlib import Path
|
6
|
+
|
7
|
+
import hydra
|
8
|
+
from hydra.core.config_store import ConfigStore
|
9
|
+
|
10
|
+
import hydraflow
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
|
14
|
+
class Config:
|
15
|
+
name: str = "a"
|
16
|
+
|
17
|
+
|
18
|
+
cs = ConfigStore.instance()
|
19
|
+
cs.store(name="config", node=Config)
|
20
|
+
|
21
|
+
|
22
|
+
@hydra.main(version_base=None, config_name="config")
|
23
|
+
def app(cfg: Config):
|
24
|
+
hydraflow.set_experiment()
|
25
|
+
|
26
|
+
with hydraflow.start_run(cfg) as run:
|
27
|
+
with hydraflow.chdir_hydra_output():
|
28
|
+
Path("a.txt").write_text("chdir_hydra_output")
|
29
|
+
|
30
|
+
with hydraflow.chdir_artifact(run):
|
31
|
+
Path("b.txt").write_text("chdir_artifact")
|
32
|
+
|
33
|
+
if cfg.name == "b":
|
34
|
+
raise ValueError
|
35
|
+
|
36
|
+
if cfg.name == "c":
|
37
|
+
sys.exit(1)
|
38
|
+
|
39
|
+
with hydraflow.start_run(cfg, run_id=run.info.run_id): # Skip log config
|
40
|
+
pass
|
41
|
+
|
42
|
+
|
43
|
+
if __name__ == "__main__":
|
44
|
+
app()
|
@@ -0,0 +1,31 @@
|
|
1
|
+
import pytest
|
2
|
+
from mlflow.entities import Run
|
3
|
+
|
4
|
+
from hydraflow.run_collection import RunCollection
|
5
|
+
from hydraflow.utils import get_artifact_path, get_hydra_output_dir
|
6
|
+
|
7
|
+
|
8
|
+
@pytest.fixture(scope="module")
|
9
|
+
def rc(collect):
|
10
|
+
args = ["-m", "name=a,b,c"]
|
11
|
+
return collect("context/context.py", args)
|
12
|
+
|
13
|
+
|
14
|
+
@pytest.fixture(scope="module", params=range(3))
|
15
|
+
def run(rc: RunCollection, request: pytest.FixtureRequest):
|
16
|
+
return rc[request.param]
|
17
|
+
|
18
|
+
|
19
|
+
def test_chdir_hydra_output(run: Run):
|
20
|
+
path = get_hydra_output_dir(run)
|
21
|
+
assert (path / "a.txt").read_text() == "chdir_hydra_output"
|
22
|
+
|
23
|
+
|
24
|
+
def test_chdir_artifact(run: Run):
|
25
|
+
path = get_artifact_path(run, "b.txt")
|
26
|
+
assert path.read_text() == "chdir_artifact"
|
27
|
+
|
28
|
+
|
29
|
+
def test_log_run(run: Run):
|
30
|
+
path = get_artifact_path(run, "a.txt")
|
31
|
+
assert path.read_text() == "chdir_hydra_output"
|
File without changes
|
@@ -0,0 +1,37 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass, field
|
4
|
+
|
5
|
+
import hydra
|
6
|
+
from hydra.core.config_store import ConfigStore
|
7
|
+
|
8
|
+
import hydraflow
|
9
|
+
|
10
|
+
|
11
|
+
@dataclass
|
12
|
+
class Data:
|
13
|
+
x: list[int] = field(default_factory=lambda: [1, 2, 3])
|
14
|
+
y: list[int] = field(default_factory=lambda: [4, 5, 6])
|
15
|
+
|
16
|
+
|
17
|
+
@dataclass
|
18
|
+
class Config:
|
19
|
+
host: str = "localhost"
|
20
|
+
port: int = 3306
|
21
|
+
data: Data = field(default_factory=Data)
|
22
|
+
|
23
|
+
|
24
|
+
cs = ConfigStore.instance()
|
25
|
+
cs.store(name="config", node=Config)
|
26
|
+
|
27
|
+
|
28
|
+
@hydra.main(version_base=None, config_name="config")
|
29
|
+
def app(cfg: Config):
|
30
|
+
hydraflow.set_experiment()
|
31
|
+
|
32
|
+
with hydraflow.start_run(cfg):
|
33
|
+
pass
|
34
|
+
|
35
|
+
|
36
|
+
if __name__ == "__main__":
|
37
|
+
app()
|
@@ -0,0 +1,33 @@
|
|
1
|
+
import pytest
|
2
|
+
from mlflow.entities import Run
|
3
|
+
|
4
|
+
from hydraflow.run_collection import RunCollection
|
5
|
+
|
6
|
+
|
7
|
+
@pytest.fixture(scope="module")
|
8
|
+
def rc(collect):
|
9
|
+
args = ["host=a"]
|
10
|
+
return collect("param/param.py", args)
|
11
|
+
|
12
|
+
|
13
|
+
@pytest.fixture(scope="module")
|
14
|
+
def run(rc: RunCollection):
|
15
|
+
return rc.first()
|
16
|
+
|
17
|
+
|
18
|
+
def test_get_params_str(run: Run):
|
19
|
+
from hydraflow.param import get_params
|
20
|
+
|
21
|
+
assert get_params(run, "host") == ("a",)
|
22
|
+
|
23
|
+
|
24
|
+
def test_get_params_list(run: Run):
|
25
|
+
from hydraflow.param import get_params
|
26
|
+
|
27
|
+
assert get_params(run, ["host"], ["port"]) == ("a", "3306")
|
28
|
+
|
29
|
+
|
30
|
+
def test_get_values(run: Run):
|
31
|
+
from hydraflow.param import get_values
|
32
|
+
|
33
|
+
assert get_values(run, ["host", "port"], [str, int]) == ("a", 3306)
|