hydraflow 0.4.6__tar.gz → 0.5.0__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {hydraflow-0.4.6 → hydraflow-0.5.0}/PKG-INFO +1 -16
- {hydraflow-0.4.6 → hydraflow-0.5.0}/README.md +0 -10
- {hydraflow-0.4.6 → hydraflow-0.5.0}/pyproject.toml +2 -12
- {hydraflow-0.4.6 → hydraflow-0.5.0}/src/hydraflow/__init__.py +1 -5
- {hydraflow-0.4.6 → hydraflow-0.5.0}/src/hydraflow/context.py +1 -106
- {hydraflow-0.4.6 → hydraflow-0.5.0}/src/hydraflow/run_collection.py +1 -1
- {hydraflow-0.4.6 → hydraflow-0.5.0}/src/hydraflow/run_data.py +3 -3
- {hydraflow-0.4.6/tests/scripts → hydraflow-0.5.0/tests/apps}/app.py +4 -14
- hydraflow-0.5.0/tests/conftest.py +16 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/tests/test_app.py +8 -14
- {hydraflow-0.4.6 → hydraflow-0.5.0}/tests/test_context.py +2 -22
- {hydraflow-0.4.6 → hydraflow-0.5.0}/tests/test_log_run.py +4 -8
- {hydraflow-0.4.6 → hydraflow-0.5.0}/tests/test_mlflow.py +1 -1
- {hydraflow-0.4.6 → hydraflow-0.5.0}/tests/test_param.py +3 -1
- hydraflow-0.5.0/tests/test_run_collection.py +465 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/tests/test_run_data.py +14 -17
- hydraflow-0.5.0/tests/test_run_info.py +42 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/tests/test_utils.py +10 -12
- hydraflow-0.4.6/src/hydraflow/asyncio.py +0 -227
- hydraflow-0.4.6/src/hydraflow/progress.py +0 -184
- hydraflow-0.4.6/tests/integ/app.py +0 -62
- hydraflow-0.4.6/tests/scripts/__init__.py +0 -0
- hydraflow-0.4.6/tests/scripts/progress.py +0 -65
- hydraflow-0.4.6/tests/scripts/watch.py +0 -9
- hydraflow-0.4.6/tests/test_asyncio.py +0 -221
- hydraflow-0.4.6/tests/test_progress.py +0 -12
- hydraflow-0.4.6/tests/test_run_collection.py +0 -487
- hydraflow-0.4.6/tests/test_run_info.py +0 -48
- hydraflow-0.4.6/tests/test_version.py +0 -5
- hydraflow-0.4.6/tests/test_watch.py +0 -28
- {hydraflow-0.4.6 → hydraflow-0.5.0}/.devcontainer/devcontainer.json +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/.devcontainer/postCreate.sh +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/.devcontainer/starship.toml +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/.gitattributes +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/.gitignore +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/LICENSE +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/apps/quickstart.py +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/mkdocs.yml +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/src/hydraflow/config.py +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/src/hydraflow/mlflow.py +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/src/hydraflow/param.py +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/src/hydraflow/py.typed +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/src/hydraflow/run_info.py +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/src/hydraflow/utils.py +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/tests/__init__.py +0 -0
- {hydraflow-0.4.6/tests/integ → hydraflow-0.5.0/tests/apps}/__init__.py +0 -0
- {hydraflow-0.4.6 → hydraflow-0.5.0}/tests/test_config.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: hydraflow
|
3
|
-
Version: 0.4.6
|
3
|
+
Version: 0.5.0
|
4
4
|
Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
|
5
5
|
Project-URL: Documentation, https://github.com/daizutabi/hydraflow
|
6
6
|
Project-URL: Source, https://github.com/daizutabi/hydraflow
|
@@ -37,12 +37,7 @@ Classifier: Programming Language :: Python :: 3.12
|
|
37
37
|
Classifier: Programming Language :: Python :: 3.13
|
38
38
|
Requires-Python: >=3.10
|
39
39
|
Requires-Dist: hydra-core>=1.3
|
40
|
-
Requires-Dist: joblib
|
41
40
|
Requires-Dist: mlflow>=2.15
|
42
|
-
Requires-Dist: polars
|
43
|
-
Requires-Dist: rich
|
44
|
-
Requires-Dist: watchdog
|
45
|
-
Requires-Dist: watchfiles
|
46
41
|
Description-Content-Type: text/markdown
|
47
42
|
|
48
43
|
# Hydraflow
|
@@ -119,16 +114,6 @@ def my_app(cfg: MySQLConfig) -> None:
|
|
119
114
|
with hydraflow.start_run():
|
120
115
|
# Your app code below.
|
121
116
|
|
122
|
-
with hydraflow.watch(callback):
|
123
|
-
# Watch files in the MLflow artifact directory.
|
124
|
-
# You can update metrics or log other artifacts
|
125
|
-
# according to the watched files in your callback
|
126
|
-
# function.
|
127
|
-
pass
|
128
|
-
|
129
|
-
# Your callback function here.
|
130
|
-
def callback(file: Path) -> None:
|
131
|
-
pass
|
132
117
|
|
133
118
|
if __name__ == "__main__":
|
134
119
|
my_app()
|
@@ -72,16 +72,6 @@ def my_app(cfg: MySQLConfig) -> None:
|
|
72
72
|
with hydraflow.start_run():
|
73
73
|
# Your app code below.
|
74
74
|
|
75
|
-
with hydraflow.watch(callback):
|
76
|
-
# Watch files in the MLflow artifact directory.
|
77
|
-
# You can update metrics or log other artifacts
|
78
|
-
# according to the watched files in your callback
|
79
|
-
# function.
|
80
|
-
pass
|
81
|
-
|
82
|
-
# Your callback function here.
|
83
|
-
def callback(file: Path) -> None:
|
84
|
-
pass
|
85
75
|
|
86
76
|
if __name__ == "__main__":
|
87
77
|
my_app()
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "hydraflow"
|
7
|
-
version = "0.4.6"
|
7
|
+
version = "0.5.0"
|
8
8
|
description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
|
9
9
|
readme = "README.md"
|
10
10
|
license = { file = "LICENSE" }
|
@@ -19,15 +19,7 @@ classifiers = [
|
|
19
19
|
"Programming Language :: Python :: 3.13",
|
20
20
|
]
|
21
21
|
requires-python = ">=3.10"
|
22
|
-
dependencies = [
|
23
|
-
"hydra-core>=1.3",
|
24
|
-
"joblib",
|
25
|
-
"mlflow>=2.15",
|
26
|
-
"polars",
|
27
|
-
"rich",
|
28
|
-
"watchdog",
|
29
|
-
"watchfiles",
|
30
|
-
]
|
22
|
+
dependencies = ["hydra-core>=1.3", "mlflow>=2.15"]
|
31
23
|
|
32
24
|
[project.urls]
|
33
25
|
Documentation = "https://github.com/daizutabi/hydraflow"
|
@@ -40,7 +32,6 @@ dev-dependencies = [
|
|
40
32
|
"mkapi",
|
41
33
|
"mkdocs-material",
|
42
34
|
"mkdocs>=1.6",
|
43
|
-
"pytest-asyncio",
|
44
35
|
"pytest-clarity",
|
45
36
|
"pytest-cov",
|
46
37
|
"pytest-randomly",
|
@@ -65,7 +56,6 @@ filterwarnings = [
|
|
65
56
|
"ignore:Support for class-based `config` is deprecated",
|
66
57
|
"ignore:Pydantic V1 style",
|
67
58
|
]
|
68
|
-
asyncio_default_fixture_loop_scope = "function"
|
69
59
|
|
70
60
|
[tool.coverage.report]
|
71
61
|
exclude_lines = ["no cov", "raise NotImplementedError", "if TYPE_CHECKING:"]
|
@@ -1,9 +1,8 @@
|
|
1
1
|
"""Integrate Hydra and MLflow to manage and track machine learning experiments."""
|
2
2
|
|
3
3
|
from .config import select_config, select_overrides
|
4
|
-
from .context import chdir_artifact, chdir_hydra_output, log_run, start_run, watch
|
4
|
+
from .context import chdir_artifact, chdir_hydra_output, log_run, start_run
|
5
5
|
from .mlflow import list_runs, search_runs, set_experiment
|
6
|
-
from .progress import multi_tasks_progress, parallel_progress
|
7
6
|
from .run_collection import RunCollection
|
8
7
|
from .utils import (
|
9
8
|
get_artifact_dir,
|
@@ -25,13 +24,10 @@ __all__ = [
|
|
25
24
|
"load_config",
|
26
25
|
"load_overrides",
|
27
26
|
"log_run",
|
28
|
-
"multi_tasks_progress",
|
29
|
-
"parallel_progress",
|
30
27
|
"remove_run",
|
31
28
|
"search_runs",
|
32
29
|
"select_config",
|
33
30
|
"select_overrides",
|
34
31
|
"set_experiment",
|
35
32
|
"start_run",
|
36
|
-
"watch",
|
37
33
|
]
|
@@ -4,7 +4,6 @@ from __future__ import annotations
|
|
4
4
|
|
5
5
|
import logging
|
6
6
|
import os
|
7
|
-
import time
|
8
7
|
from contextlib import contextmanager
|
9
8
|
from pathlib import Path
|
10
9
|
from typing import TYPE_CHECKING
|
@@ -12,14 +11,11 @@ from typing import TYPE_CHECKING
|
|
12
11
|
import mlflow
|
13
12
|
import mlflow.artifacts
|
14
13
|
from hydra.core.hydra_config import HydraConfig
|
15
|
-
from watchdog.events import FileModifiedEvent, PatternMatchingEventHandler
|
16
|
-
from watchdog.observers import Observer
|
17
14
|
|
18
15
|
from hydraflow.mlflow import log_params
|
19
|
-
from hydraflow.run_info import get_artifact_dir
|
20
16
|
|
21
17
|
if TYPE_CHECKING:
|
22
|
-
from collections.abc import Callable, Iterator
|
18
|
+
from collections.abc import Iterator
|
23
19
|
|
24
20
|
from mlflow.entities.run import Run
|
25
21
|
|
@@ -64,14 +60,8 @@ def log_run(
|
|
64
60
|
output_subdir = output_dir / (hc.output_subdir or "")
|
65
61
|
mlflow.log_artifacts(output_subdir.as_posix(), hc.output_subdir)
|
66
62
|
|
67
|
-
def log_artifact(path: Path) -> None:
|
68
|
-
local_path = (output_dir / path).as_posix()
|
69
|
-
mlflow.log_artifact(local_path)
|
70
|
-
|
71
63
|
try:
|
72
64
|
yield
|
73
|
-
# with watch(log_artifact, output_dir, ignore_log=False):
|
74
|
-
# yield
|
75
65
|
|
76
66
|
except Exception as e:
|
77
67
|
msg = f"Error during log_run: {e}"
|
@@ -146,101 +136,6 @@ def start_run( # noqa: PLR0913
|
|
146
136
|
yield run
|
147
137
|
|
148
138
|
|
149
|
-
@contextmanager
|
150
|
-
def watch(
|
151
|
-
callback: Callable[[Path], None],
|
152
|
-
dir: Path | str = "", # noqa: A002
|
153
|
-
*,
|
154
|
-
timeout: int = 60,
|
155
|
-
ignore_patterns: list[str] | None = None,
|
156
|
-
ignore_log: bool = True,
|
157
|
-
) -> Iterator[None]:
|
158
|
-
"""Watch the given directory for changes.
|
159
|
-
|
160
|
-
This context manager sets up a file system watcher on the specified directory.
|
161
|
-
When a file modification is detected, the provided function is called with
|
162
|
-
the path of the modified file. The watcher runs for the specified timeout
|
163
|
-
period or until the context is exited.
|
164
|
-
|
165
|
-
Args:
|
166
|
-
callback (Callable[[Path], None]): The function to call when a change is
|
167
|
-
detected. It should accept a single argument of type `Path`,
|
168
|
-
which is the path of the modified file.
|
169
|
-
dir (Path | str): The directory to watch. If not specified,
|
170
|
-
the current MLflow artifact URI is used. Defaults to "".
|
171
|
-
timeout (int): The timeout period in seconds for the watcher
|
172
|
-
to run after the context is exited. Defaults to 60.
|
173
|
-
ignore_patterns (list[str] | None): A list of glob patterns to ignore.
|
174
|
-
Defaults to None.
|
175
|
-
ignore_log (bool): Whether to ignore log files. Defaults to True.
|
176
|
-
|
177
|
-
Yields:
|
178
|
-
None
|
179
|
-
|
180
|
-
Example:
|
181
|
-
```python
|
182
|
-
with watch(log_artifact, "/path/to/dir"):
|
183
|
-
# Perform operations while watching the directory for changes
|
184
|
-
pass
|
185
|
-
```
|
186
|
-
|
187
|
-
"""
|
188
|
-
dir = dir or get_artifact_dir() # noqa: A001
|
189
|
-
if isinstance(dir, Path):
|
190
|
-
dir = dir.as_posix() # noqa: A001
|
191
|
-
|
192
|
-
handler = Handler(callback, ignore_patterns=ignore_patterns, ignore_log=ignore_log)
|
193
|
-
observer = Observer()
|
194
|
-
observer.schedule(handler, dir, recursive=True)
|
195
|
-
observer.start()
|
196
|
-
|
197
|
-
try:
|
198
|
-
yield
|
199
|
-
|
200
|
-
except Exception as e:
|
201
|
-
msg = f"Error during watch: {e}"
|
202
|
-
log.exception(msg)
|
203
|
-
raise
|
204
|
-
|
205
|
-
finally:
|
206
|
-
elapsed = 0
|
207
|
-
while not observer.event_queue.empty():
|
208
|
-
time.sleep(0.2)
|
209
|
-
elapsed += 0.2
|
210
|
-
if elapsed > timeout:
|
211
|
-
break
|
212
|
-
|
213
|
-
observer.stop()
|
214
|
-
observer.join()
|
215
|
-
|
216
|
-
|
217
|
-
class Handler(PatternMatchingEventHandler):
|
218
|
-
"""Monitor file changes and call the given function when a change is detected."""
|
219
|
-
|
220
|
-
def __init__(
|
221
|
-
self,
|
222
|
-
func: Callable[[Path], None],
|
223
|
-
*,
|
224
|
-
ignore_patterns: list[str] | None = None,
|
225
|
-
ignore_log: bool = True,
|
226
|
-
) -> None:
|
227
|
-
self.func = func
|
228
|
-
|
229
|
-
if ignore_log:
|
230
|
-
if ignore_patterns:
|
231
|
-
ignore_patterns.append("*.log")
|
232
|
-
else:
|
233
|
-
ignore_patterns = ["*.log"]
|
234
|
-
|
235
|
-
super().__init__(ignore_patterns=ignore_patterns)
|
236
|
-
|
237
|
-
def on_modified(self, event: FileModifiedEvent) -> None:
|
238
|
-
"""Modify when a file is modified."""
|
239
|
-
file = Path(str(event.src_path))
|
240
|
-
if file.is_file():
|
241
|
-
self.func(file)
|
242
|
-
|
243
|
-
|
244
139
|
@contextmanager
|
245
140
|
def chdir_hydra_output() -> Iterator[Path]:
|
246
141
|
"""Change the current working directory to the hydra output directory.
|
@@ -575,7 +575,7 @@ class RunCollection:
|
|
575
575
|
"""
|
576
576
|
return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir) # noqa: A001
|
577
577
|
|
578
|
-
def group_by(
|
578
|
+
def groupby(
|
579
579
|
self,
|
580
580
|
names: str | list[str],
|
581
581
|
) -> dict[str | None | tuple[str | None, ...], RunCollection]:
|
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|
4
4
|
|
5
5
|
from typing import TYPE_CHECKING
|
6
6
|
|
7
|
-
from polars import DataFrame
|
7
|
+
from pandas import DataFrame
|
8
8
|
|
9
9
|
from hydraflow.config import collect_params
|
10
10
|
|
@@ -33,10 +33,10 @@ class RunCollectionData:
|
|
33
33
|
|
34
34
|
@property
|
35
35
|
def config(self) -> DataFrame:
|
36
|
-
"""Get the runs' configurations as a
|
36
|
+
"""Get the runs' configurations as a DataFrame.
|
37
37
|
|
38
38
|
Returns:
|
39
|
-
A
|
39
|
+
A DataFrame containing the runs' configurations.
|
40
40
|
|
41
41
|
"""
|
42
42
|
return DataFrame(self._runs.map_config(collect_params))
|
@@ -1,7 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
3
|
import logging
|
4
|
-
import time
|
5
4
|
from dataclasses import dataclass, field
|
6
5
|
from pathlib import Path
|
7
6
|
|
@@ -41,8 +40,9 @@ def app(cfg: MySQLConfig):
|
|
41
40
|
assert cfg.get("values") == [1, 2, 3] # type: ignore
|
42
41
|
|
43
42
|
hydraflow.set_experiment(prefix="_", suffix="_")
|
44
|
-
with hydraflow.start_run(cfg) as run:
|
45
|
-
|
43
|
+
with hydraflow.start_run(cfg, synchronous=True) as run:
|
44
|
+
msg = f"START, {cfg.host}, {cfg.port} "
|
45
|
+
log.info(msg)
|
46
46
|
|
47
47
|
artifact_dir = hydraflow.get_artifact_dir()
|
48
48
|
output_dir = hydraflow.get_hydra_output_dir()
|
@@ -52,10 +52,6 @@ def app(cfg: MySQLConfig):
|
|
52
52
|
mlflow.log_text("A " + artifact_dir.as_posix(), "artifact_dir.txt")
|
53
53
|
mlflow.log_text("B " + output_dir.as_posix(), "output_dir.txt")
|
54
54
|
|
55
|
-
# with hydraflow.watch(callback, ignore_patterns=["b.txt"]):
|
56
|
-
# (artifact_dir / "a.txt").write_text("abc")
|
57
|
-
# time.sleep(0.1)
|
58
|
-
|
59
55
|
(artifact_dir / "a.txt").write_text("abc")
|
60
56
|
|
61
57
|
mlflow.log_metric("m", cfg.port + 1, 1)
|
@@ -65,16 +61,10 @@ def app(cfg: MySQLConfig):
|
|
65
61
|
assert hydraflow.get_overrides() == hydraflow.load_overrides(run)
|
66
62
|
|
67
63
|
if cfg.host == "error":
|
68
|
-
raise Exception("error")
|
64
|
+
raise Exception("error") # noqa: TRY002
|
69
65
|
|
70
66
|
log.info("END")
|
71
67
|
|
72
68
|
|
73
|
-
def callback(path: Path):
|
74
|
-
log.info(f"WATCH, {path.as_posix()}")
|
75
|
-
m = len(path.read_text()) # len("abc") == 3
|
76
|
-
# mlflow.log_metric("watch", m, 1, synchronous=True)
|
77
|
-
|
78
|
-
|
79
69
|
if __name__ == "__main__":
|
80
70
|
app()
|
@@ -0,0 +1,16 @@
|
|
1
|
+
import os
|
2
|
+
import uuid
|
3
|
+
from pathlib import Path
|
4
|
+
|
5
|
+
import mlflow
|
6
|
+
import pytest
|
7
|
+
|
8
|
+
|
9
|
+
@pytest.fixture(scope="module")
|
10
|
+
def experiment_name(tmp_path_factory: pytest.TempPathFactory):
|
11
|
+
cwd = Path.cwd()
|
12
|
+
name = str(uuid.uuid4())
|
13
|
+
os.chdir(tmp_path_factory.mktemp(name))
|
14
|
+
mlflow.set_experiment(name)
|
15
|
+
yield name
|
16
|
+
os.chdir(cwd)
|
@@ -1,20 +1,14 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
1
|
import os
|
4
2
|
import subprocess
|
5
3
|
import sys
|
6
4
|
from pathlib import Path
|
7
|
-
from typing import TYPE_CHECKING
|
8
5
|
|
9
6
|
import mlflow
|
10
7
|
import pytest
|
11
8
|
from mlflow.entities import RunStatus
|
12
|
-
from omegaconf import OmegaConf
|
13
|
-
|
14
|
-
if TYPE_CHECKING:
|
15
|
-
from omegaconf import DictConfig
|
9
|
+
from omegaconf import DictConfig, OmegaConf
|
16
10
|
|
17
|
-
|
11
|
+
from hydraflow.run_collection import RunCollection
|
18
12
|
|
19
13
|
|
20
14
|
@pytest.fixture(scope="module")
|
@@ -23,7 +17,7 @@ def rc(tmp_path_factory: pytest.TempPathFactory):
|
|
23
17
|
|
24
18
|
cwd = Path.cwd()
|
25
19
|
|
26
|
-
file = Path("tests/scripts/app.py").absolute()
|
20
|
+
file = Path("tests/apps/app.py").absolute()
|
27
21
|
os.chdir(tmp_path_factory.mktemp("test_app"))
|
28
22
|
|
29
23
|
args = [sys.executable, file.as_posix(), "-m"]
|
@@ -117,7 +111,7 @@ def test_app_data_config(rc: RunCollection):
|
|
117
111
|
def test_app_data_config_list(rc: RunCollection):
|
118
112
|
config = rc.data.config
|
119
113
|
values = config["values"].to_list()
|
120
|
-
assert str(config
|
114
|
+
assert str(config["values"].dtypes) == "object"
|
121
115
|
for x in values:
|
122
116
|
assert isinstance(x, list)
|
123
117
|
assert x == [1, 2, 3]
|
@@ -159,8 +153,8 @@ def test_app_map_config(rc: RunCollection):
|
|
159
153
|
assert ports == [2, 3, 2, 3]
|
160
154
|
|
161
155
|
|
162
|
-
def test_app_group_by(rc: RunCollection):
|
163
|
-
grouped = rc.
|
156
|
+
def test_app_groupby(rc: RunCollection):
|
157
|
+
grouped = rc.groupby("host")
|
164
158
|
assert len(grouped) == 2
|
165
159
|
assert grouped["x"].data.params["port"] == ["1", "2"]
|
166
160
|
assert grouped["x"].data.params["host"] == ["x", "x"]
|
@@ -170,8 +164,8 @@ def test_app_group_by(rc: RunCollection):
|
|
170
164
|
assert grouped["y"].data.params["values"] == ["[1, 2, 3]", "[1, 2, 3]"]
|
171
165
|
|
172
166
|
|
173
|
-
def
|
174
|
-
grouped = rc.
|
167
|
+
def test_app_groupby_list(rc: RunCollection):
|
168
|
+
grouped = rc.groupby(["host"])
|
175
169
|
assert len(grouped) == 2
|
176
170
|
assert ("x",) in grouped
|
177
171
|
assert ("y",) in grouped
|
@@ -1,16 +1,15 @@
|
|
1
|
-
import time
|
2
1
|
from pathlib import Path
|
3
2
|
from unittest.mock import MagicMock, patch
|
4
3
|
|
5
4
|
import mlflow
|
6
5
|
import pytest
|
7
6
|
|
8
|
-
from hydraflow.context import log_run, start_run, watch
|
7
|
+
from hydraflow.context import log_run, start_run
|
9
8
|
from hydraflow.run_collection import RunCollection
|
10
9
|
|
11
10
|
|
12
11
|
@pytest.fixture
|
13
|
-
def runs(monkeypatch, tmp_path):
|
12
|
+
def runs(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
14
13
|
from hydraflow.mlflow import list_runs
|
15
14
|
|
16
15
|
monkeypatch.chdir(tmp_path)
|
@@ -67,22 +66,3 @@ def test_log_run_error_handling(tmp_path: Path):
|
|
67
66
|
with pytest.raises(Exception, match="Test exception"):
|
68
67
|
with log_run(config):
|
69
68
|
pass
|
70
|
-
|
71
|
-
|
72
|
-
def test_watch_context_manager(tmp_path: Path):
|
73
|
-
test_dir = tmp_path / "test_watch"
|
74
|
-
test_dir.mkdir(parents=True, exist_ok=True)
|
75
|
-
test_file = test_dir / "test_file.txt"
|
76
|
-
|
77
|
-
called = []
|
78
|
-
|
79
|
-
def mock_func(path: Path):
|
80
|
-
assert path == test_file
|
81
|
-
called.append(path)
|
82
|
-
|
83
|
-
with watch(mock_func, test_dir):
|
84
|
-
test_file.write_text("new content")
|
85
|
-
time.sleep(1)
|
86
|
-
|
87
|
-
assert len(called) == 1
|
88
|
-
assert called[0] == test_file
|
@@ -1,5 +1,3 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
1
|
import os
|
4
2
|
import subprocess
|
5
3
|
import sys
|
@@ -13,7 +11,7 @@ from mlflow.entities.run import Run
|
|
13
11
|
|
14
12
|
@pytest.fixture(scope="module")
|
15
13
|
def runs(tmp_path_factory: pytest.TempPathFactory):
|
16
|
-
file = Path("tests/scripts/app.py").absolute()
|
14
|
+
file = Path("tests/apps/app.py").absolute()
|
17
15
|
|
18
16
|
cwd = Path.cwd()
|
19
17
|
os.chdir(tmp_path_factory.mktemp("test_log_run"))
|
@@ -32,11 +30,9 @@ def runs(tmp_path_factory: pytest.TempPathFactory):
|
|
32
30
|
os.chdir(cwd)
|
33
31
|
|
34
32
|
|
35
|
-
@pytest.fixture(params=range(4))
|
36
|
-
def run(runs, request):
|
37
|
-
|
38
|
-
assert isinstance(run, Run)
|
39
|
-
return run
|
33
|
+
@pytest.fixture(scope="module", params=range(4))
|
34
|
+
def run(runs: list[Run], request: pytest.FixtureRequest):
|
35
|
+
return runs[request.param]
|
40
36
|
|
41
37
|
|
42
38
|
@pytest.fixture
|
@@ -1,9 +1,11 @@
|
|
1
|
+
from pathlib import Path
|
2
|
+
|
1
3
|
import mlflow
|
2
4
|
import pytest
|
3
5
|
|
4
6
|
|
5
7
|
@pytest.fixture
|
6
|
-
def param(monkeypatch, tmp_path):
|
8
|
+
def param(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
7
9
|
def param(value):
|
8
10
|
monkeypatch.chdir(tmp_path)
|
9
11
|
mlflow.set_experiment("test_param")
|