hydraflow 0.2.14__tar.gz → 0.2.16__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {hydraflow-0.2.14 → hydraflow-0.2.16}/.gitignore +2 -1
- {hydraflow-0.2.14 → hydraflow-0.2.16}/PKG-INFO +3 -10
- {hydraflow-0.2.14 → hydraflow-0.2.16}/pyproject.toml +27 -13
- {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/asyncio.py +39 -19
- {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/config.py +3 -3
- {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/context.py +28 -20
- {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/info.py +1 -1
- {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/mlflow.py +4 -2
- {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/progress.py +2 -2
- {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/run_collection.py +26 -9
- hydraflow-0.2.16/tests/__init__.py +0 -0
- hydraflow-0.2.16/tests/scripts/__init__.py +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_app.py +8 -6
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_asyncio.py +9 -3
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_config.py +5 -1
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_log_run.py +1 -2
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_progress.py +2 -2
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_run_collection.py +11 -7
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_watch.py +1 -1
- {hydraflow-0.2.14 → hydraflow-0.2.16}/.devcontainer/devcontainer.json +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/.devcontainer/postCreate.sh +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/.devcontainer/starship.toml +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/.gitattributes +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/LICENSE +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/README.md +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/mkdocs.yml +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/__init__.py +0 -0
- /hydraflow-0.2.14/tests/scripts/__init__.py → /hydraflow-0.2.16/src/hydraflow/py.typed +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/scripts/app.py +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/scripts/progress.py +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/scripts/watch.py +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_context.py +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_info.py +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_mlflow.py +0 -0
- {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_version.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: hydraflow
|
3
|
-
Version: 0.2.14
|
3
|
+
Version: 0.2.16
|
4
4
|
Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
|
5
5
|
Project-URL: Documentation, https://github.com/daizutabi/hydraflow
|
6
6
|
Project-URL: Source, https://github.com/daizutabi/hydraflow
|
@@ -14,19 +14,12 @@ Classifier: Programming Language :: Python :: 3.10
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.11
|
15
15
|
Classifier: Programming Language :: Python :: 3.12
|
16
16
|
Requires-Python: >=3.10
|
17
|
-
Requires-Dist: hydra-core
|
17
|
+
Requires-Dist: hydra-core>=1.3
|
18
18
|
Requires-Dist: joblib
|
19
|
-
Requires-Dist: mlflow
|
19
|
+
Requires-Dist: mlflow>=2.15
|
20
20
|
Requires-Dist: rich
|
21
|
-
Requires-Dist: setuptools
|
22
21
|
Requires-Dist: watchdog
|
23
22
|
Requires-Dist: watchfiles
|
24
|
-
Provides-Extra: dev
|
25
|
-
Requires-Dist: pytest-asyncio; extra == 'dev'
|
26
|
-
Requires-Dist: pytest-clarity; extra == 'dev'
|
27
|
-
Requires-Dist: pytest-cov; extra == 'dev'
|
28
|
-
Requires-Dist: pytest-randomly; extra == 'dev'
|
29
|
-
Requires-Dist: pytest-xdist; extra == 'dev'
|
30
23
|
Description-Content-Type: text/markdown
|
31
24
|
|
32
25
|
# Hydraflow
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "hydraflow"
|
7
|
-
version = "0.2.14"
|
7
|
+
version = "0.2.16"
|
8
8
|
description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
|
9
9
|
readme = "README.md"
|
10
10
|
license = "MIT"
|
@@ -18,29 +18,29 @@ classifiers = [
|
|
18
18
|
]
|
19
19
|
requires-python = ">=3.10"
|
20
20
|
dependencies = [
|
21
|
-
"hydra-core",
|
21
|
+
"hydra-core>=1.3",
|
22
22
|
"joblib",
|
23
|
-
"mlflow",
|
23
|
+
"mlflow>=2.15",
|
24
24
|
"rich",
|
25
|
-
"setuptools",
|
26
25
|
"watchdog",
|
27
26
|
"watchfiles",
|
28
27
|
]
|
29
28
|
|
30
|
-
[project.optional-dependencies]
|
31
|
-
dev = [
|
29
|
+
[project.urls]
|
30
|
+
Documentation = "https://github.com/daizutabi/hydraflow"
|
31
|
+
Source = "https://github.com/daizutabi/hydraflow"
|
32
|
+
Issues = "https://github.com/daizutabi/hydraflow/issues"
|
33
|
+
|
34
|
+
[tool.uv]
|
35
|
+
dev-dependencies = [
|
32
36
|
"pytest-asyncio",
|
33
37
|
"pytest-clarity",
|
34
38
|
"pytest-cov",
|
35
39
|
"pytest-randomly",
|
36
40
|
"pytest-xdist",
|
41
|
+
"ruff",
|
37
42
|
]
|
38
43
|
|
39
|
-
[project.urls]
|
40
|
-
Documentation = "https://github.com/daizutabi/hydraflow"
|
41
|
-
Source = "https://github.com/daizutabi/hydraflow"
|
42
|
-
Issues = "https://github.com/daizutabi/hydraflow/issues"
|
43
|
-
|
44
44
|
[tool.hatch.build.targets.sdist]
|
45
45
|
exclude = ["/.github", "/docs"]
|
46
46
|
|
@@ -62,7 +62,21 @@ exclude_lines = ["no cov", "raise NotImplementedError", "if TYPE_CHECKING:"]
|
|
62
62
|
|
63
63
|
[tool.ruff]
|
64
64
|
line-length = 88
|
65
|
-
target-version = "
|
65
|
+
target-version = "py310"
|
66
66
|
|
67
67
|
[tool.ruff.lint]
|
68
|
-
|
68
|
+
select = ["ALL"]
|
69
|
+
ignore = [
|
70
|
+
"ANN003",
|
71
|
+
"ANN401",
|
72
|
+
"ARG002",
|
73
|
+
"B904",
|
74
|
+
"D",
|
75
|
+
"EM101",
|
76
|
+
"PGH003",
|
77
|
+
"TRY003",
|
78
|
+
]
|
79
|
+
exclude = ["tests/scripts/*.py"]
|
80
|
+
|
81
|
+
[tool.ruff.lint.per-file-ignores]
|
82
|
+
"tests/*" = ["A001", "ANN", "ARG", "FBT", "PLR", "PT", "S", "SIM117", "SLF"]
|
@@ -42,7 +42,10 @@ async def execute_command(
|
|
42
42
|
"""
|
43
43
|
try:
|
44
44
|
process = await asyncio.create_subprocess_exec(
|
45
|
-
program,
|
45
|
+
program,
|
46
|
+
*args,
|
47
|
+
stdout=PIPE,
|
48
|
+
stderr=PIPE,
|
46
49
|
)
|
47
50
|
await asyncio.gather(
|
48
51
|
process_stream(process.stdout, stdout),
|
@@ -51,7 +54,8 @@ async def execute_command(
|
|
51
54
|
returncode = await process.wait()
|
52
55
|
|
53
56
|
except Exception as e:
|
54
|
-
|
57
|
+
msg = f"Error running command: {e}"
|
58
|
+
logger.exception(msg)
|
55
59
|
returncode = 1
|
56
60
|
|
57
61
|
finally:
|
@@ -103,11 +107,15 @@ async def monitor_file_changes(
|
|
103
107
|
str_paths = [str(path) for path in paths]
|
104
108
|
try:
|
105
109
|
async for changes in watchfiles.awatch(
|
106
|
-
*str_paths,
|
110
|
+
*str_paths,
|
111
|
+
stop_event=stop_event,
|
112
|
+
**awatch_kwargs,
|
107
113
|
):
|
108
114
|
callback(changes)
|
109
115
|
except Exception as e:
|
110
|
-
|
116
|
+
msg = f"Error watching files: {e}"
|
117
|
+
logger.exception(msg)
|
118
|
+
raise
|
111
119
|
|
112
120
|
|
113
121
|
async def run_and_monitor(
|
@@ -134,13 +142,16 @@ async def run_and_monitor(
|
|
134
142
|
stop_event = asyncio.Event()
|
135
143
|
run_task = asyncio.create_task(
|
136
144
|
execute_command(
|
137
|
-
program,
|
138
|
-
|
145
|
+
program,
|
146
|
+
*args,
|
147
|
+
stop_event=stop_event,
|
148
|
+
stdout=stdout,
|
149
|
+
stderr=stderr,
|
150
|
+
),
|
139
151
|
)
|
140
152
|
if watch and paths:
|
141
|
-
|
142
|
-
|
143
|
-
)
|
153
|
+
coro = monitor_file_changes(paths, watch, stop_event, **awatch_kwargs)
|
154
|
+
monitor_task = asyncio.create_task(coro)
|
144
155
|
else:
|
145
156
|
monitor_task = None
|
146
157
|
|
@@ -151,7 +162,10 @@ async def run_and_monitor(
|
|
151
162
|
await run_task
|
152
163
|
|
153
164
|
except Exception as e:
|
154
|
-
|
165
|
+
msg = f"Error in run_and_monitor: {e}"
|
166
|
+
logger.exception(msg)
|
167
|
+
raise
|
168
|
+
|
155
169
|
finally:
|
156
170
|
stop_event.set()
|
157
171
|
await run_task
|
@@ -173,18 +187,24 @@ def run(
|
|
173
187
|
"""
|
174
188
|
Run a command synchronously and optionally watch for file changes.
|
175
189
|
|
176
|
-
This function is a synchronous wrapper around the asynchronous
|
177
|
-
It runs a specified command and optionally
|
178
|
-
|
190
|
+
This function is a synchronous wrapper around the asynchronous
|
191
|
+
`run_and_monitor` function. It runs a specified command and optionally
|
192
|
+
monitors specified paths for file changes, invoking the provided callbacks for
|
193
|
+
standard output, standard error, and file changes.
|
179
194
|
|
180
195
|
Args:
|
181
196
|
program (str): The program to run.
|
182
197
|
*args (str): Arguments for the program.
|
183
|
-
stdout (Callable[[str], None] | None): Callback for handling standard
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
198
|
+
stdout (Callable[[str], None] | None): Callback for handling standard
|
199
|
+
output lines.
|
200
|
+
stderr (Callable[[str], None] | None): Callback for handling standard
|
201
|
+
error lines.
|
202
|
+
watch (Callable[[set[tuple[Change, str]]], None] | None): Callback for
|
203
|
+
handling file changes.
|
204
|
+
paths (list[str | Path] | None): List of paths to monitor for file
|
205
|
+
changes.
|
206
|
+
**awatch_kwargs: Additional keyword arguments to pass to
|
207
|
+
`watchfiles.awatch`.
|
188
208
|
|
189
209
|
Returns:
|
190
210
|
int: The return code of the process.
|
@@ -201,5 +221,5 @@ def run(
|
|
201
221
|
watch=watch,
|
202
222
|
paths=paths,
|
203
223
|
**awatch_kwargs,
|
204
|
-
)
|
224
|
+
),
|
205
225
|
)
|
@@ -33,7 +33,7 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
|
33
33
|
if config is None:
|
34
34
|
return
|
35
35
|
|
36
|
-
if not isinstance(config, (DictConfig, ListConfig)):
|
36
|
+
if not isinstance(config, DictConfig | ListConfig):
|
37
37
|
config = OmegaConf.create(config) # type: ignore
|
38
38
|
|
39
39
|
yield from _iter_params(config, prefix)
|
@@ -62,8 +62,8 @@ def _is_param(value: object) -> bool:
|
|
62
62
|
if isinstance(value, DictConfig):
|
63
63
|
return False
|
64
64
|
|
65
|
-
if isinstance(value, ListConfig):
|
66
|
-
if any(isinstance(v, (DictConfig, ListConfig)) for v in value):
|
65
|
+
if isinstance(value, ListConfig): # noqa: SIM102
|
66
|
+
if any(isinstance(v, DictConfig | ListConfig) for v in value):
|
67
67
|
return False
|
68
68
|
|
69
69
|
return True
|
@@ -75,7 +75,8 @@ def log_run(
|
|
75
75
|
yield
|
76
76
|
|
77
77
|
except Exception as e:
|
78
|
-
|
78
|
+
msg = f"Error during log_run: {e}"
|
79
|
+
log.exception(msg)
|
79
80
|
raise
|
80
81
|
|
81
82
|
finally:
|
@@ -84,7 +85,7 @@ def log_run(
|
|
84
85
|
|
85
86
|
|
86
87
|
@contextmanager
|
87
|
-
def start_run(
|
88
|
+
def start_run( # noqa: PLR0913
|
88
89
|
config: object,
|
89
90
|
*,
|
90
91
|
run_id: str | None = None,
|
@@ -112,8 +113,10 @@ def start_run(
|
|
112
113
|
parent_run_id (str | None): The parent run ID. Defaults to None.
|
113
114
|
tags (dict[str, str] | None): Tags to associate with the run. Defaults to None.
|
114
115
|
description (str | None): A description of the run. Defaults to None.
|
115
|
-
log_system_metrics (bool | None): Whether to log system metrics.
|
116
|
-
|
116
|
+
log_system_metrics (bool | None): Whether to log system metrics.
|
117
|
+
Defaults to None.
|
118
|
+
synchronous (bool | None): Whether to log parameters synchronously.
|
119
|
+
Defaults to None.
|
117
120
|
|
118
121
|
Yields:
|
119
122
|
Run: An MLflow Run object representing the started run.
|
@@ -128,24 +131,27 @@ def start_run(
|
|
128
131
|
- `log_run`: A context manager to log parameters and manage the MLflow
|
129
132
|
run context.
|
130
133
|
"""
|
131
|
-
with
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
134
|
+
with (
|
135
|
+
mlflow.start_run(
|
136
|
+
run_id=run_id,
|
137
|
+
experiment_id=experiment_id,
|
138
|
+
run_name=run_name,
|
139
|
+
nested=nested,
|
140
|
+
parent_run_id=parent_run_id,
|
141
|
+
tags=tags,
|
142
|
+
description=description,
|
143
|
+
log_system_metrics=log_system_metrics,
|
144
|
+
) as run,
|
145
|
+
log_run(config, synchronous=synchronous),
|
146
|
+
):
|
147
|
+
yield run
|
143
148
|
|
144
149
|
|
145
150
|
@contextmanager
|
146
151
|
def watch(
|
147
152
|
callback: Callable[[Path], None],
|
148
|
-
dir: Path | str = "",
|
153
|
+
dir: Path | str = "", # noqa: A002
|
154
|
+
*,
|
149
155
|
timeout: int = 60,
|
150
156
|
ignore_patterns: list[str] | None = None,
|
151
157
|
ignore_log: bool = True,
|
@@ -178,9 +184,9 @@ def watch(
|
|
178
184
|
pass
|
179
185
|
```
|
180
186
|
"""
|
181
|
-
dir = dir or get_artifact_dir()
|
187
|
+
dir = dir or get_artifact_dir() # noqa: A001
|
182
188
|
if isinstance(dir, Path):
|
183
|
-
dir = dir.as_posix()
|
189
|
+
dir = dir.as_posix() # noqa: A001
|
184
190
|
|
185
191
|
handler = Handler(callback, ignore_patterns=ignore_patterns, ignore_log=ignore_log)
|
186
192
|
observer = Observer()
|
@@ -191,7 +197,8 @@ def watch(
|
|
191
197
|
yield
|
192
198
|
|
193
199
|
except Exception as e:
|
194
|
-
|
200
|
+
msg = f"Error during watch: {e}"
|
201
|
+
log.exception(msg)
|
195
202
|
raise
|
196
203
|
|
197
204
|
finally:
|
@@ -210,6 +217,7 @@ class Handler(PatternMatchingEventHandler):
|
|
210
217
|
def __init__(
|
211
218
|
self,
|
212
219
|
func: Callable[[Path], None],
|
220
|
+
*,
|
213
221
|
ignore_patterns: list[str] | None = None,
|
214
222
|
ignore_log: bool = True,
|
215
223
|
) -> None:
|
@@ -81,7 +81,8 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
|
|
81
81
|
mlflow.log_param(key, value, synchronous=synchronous)
|
82
82
|
|
83
83
|
|
84
|
-
def search_runs(
|
84
|
+
def search_runs( # noqa: PLR0913
|
85
|
+
*,
|
85
86
|
experiment_ids: list[str] | None = None,
|
86
87
|
filter_string: str = "",
|
87
88
|
run_view_type: int = ViewType.ACTIVE_ONLY,
|
@@ -148,7 +149,8 @@ def search_runs(
|
|
148
149
|
|
149
150
|
|
150
151
|
def list_runs(
|
151
|
-
experiment_names: str | list[str] | None = None,
|
152
|
+
experiment_names: str | list[str] | None = None,
|
153
|
+
n_jobs: int = 0,
|
152
154
|
) -> RunCollection:
|
153
155
|
"""
|
154
156
|
List all runs for the specified experiments.
|
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
|
|
31
31
|
|
32
32
|
# https://github.com/jonghwanhyeon/joblib-progress/blob/main/joblib_progress/__init__.py
|
33
33
|
@contextmanager
|
34
|
-
def JoblibProgress(
|
34
|
+
def JoblibProgress( # noqa: N802
|
35
35
|
*columns: ProgressColumn | str,
|
36
36
|
description: str | None = None,
|
37
37
|
total: int | None = None,
|
@@ -68,7 +68,7 @@ def JoblibProgress(
|
|
68
68
|
task_id = progress.add_task(description, total=total)
|
69
69
|
print_progress = joblib.parallel.Parallel.print_progress
|
70
70
|
|
71
|
-
def update_progress(self: joblib.parallel.Parallel):
|
71
|
+
def update_progress(self: joblib.parallel.Parallel) -> None:
|
72
72
|
progress.update(task_id, completed=self.n_completed_tasks, refresh=True)
|
73
73
|
return print_progress(self)
|
74
74
|
|
@@ -23,8 +23,6 @@ from dataclasses import dataclass, field
|
|
23
23
|
from itertools import chain
|
24
24
|
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
|
25
25
|
|
26
|
-
from mlflow.entities.run import Run
|
27
|
-
|
28
26
|
from hydraflow.config import iter_params
|
29
27
|
from hydraflow.info import RunCollectionInfo
|
30
28
|
|
@@ -33,6 +31,7 @@ if TYPE_CHECKING:
|
|
33
31
|
from pathlib import Path
|
34
32
|
from typing import Any
|
35
33
|
|
34
|
+
from mlflow.entities.run import Run
|
36
35
|
from omegaconf import DictConfig
|
37
36
|
|
38
37
|
|
@@ -60,7 +59,7 @@ class RunCollection:
|
|
60
59
|
_info: RunCollectionInfo = field(init=False)
|
61
60
|
"""An instance of `RunCollectionInfo`."""
|
62
61
|
|
63
|
-
def __post_init__(self):
|
62
|
+
def __post_init__(self) -> None:
|
64
63
|
self._info = RunCollectionInfo(self)
|
65
64
|
|
66
65
|
def __repr__(self) -> str:
|
@@ -89,7 +88,7 @@ class RunCollection:
|
|
89
88
|
|
90
89
|
@classmethod
|
91
90
|
def from_list(cls, runs: list[Run]) -> RunCollection:
|
92
|
-
"""Create a
|
91
|
+
"""Create a `RunCollection` instance from a list of MLflow `Run` instances."""
|
93
92
|
|
94
93
|
return cls(runs)
|
95
94
|
|
@@ -120,6 +119,7 @@ class RunCollection:
|
|
120
119
|
def sort(
|
121
120
|
self,
|
122
121
|
key: Callable[[Run], Any] | None = None,
|
122
|
+
*,
|
123
123
|
reverse: bool = False,
|
124
124
|
) -> None:
|
125
125
|
self._runs.sort(key=key or (lambda x: x.info.start_time), reverse=reverse)
|
@@ -393,7 +393,7 @@ class RunCollection:
|
|
393
393
|
param_names = set()
|
394
394
|
|
395
395
|
for run in self:
|
396
|
-
for param in run.data.params
|
396
|
+
for param in run.data.params:
|
397
397
|
param_names.add(param)
|
398
398
|
|
399
399
|
return list(param_names)
|
@@ -537,10 +537,11 @@ class RunCollection:
|
|
537
537
|
Results obtained by applying the function to each artifact directory
|
538
538
|
in the collection.
|
539
539
|
"""
|
540
|
-
return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir)
|
540
|
+
return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir) # noqa: A001
|
541
541
|
|
542
542
|
def group_by(
|
543
|
-
self,
|
543
|
+
self,
|
544
|
+
*names: str | list[str],
|
544
545
|
) -> dict[tuple[str | None, ...], RunCollection]:
|
545
546
|
"""
|
546
547
|
Group runs by specified parameter names.
|
@@ -595,13 +596,19 @@ def _param_matches(run: Run, key: str, value: Any) -> bool:
|
|
595
596
|
if isinstance(value, list) and value:
|
596
597
|
return type(value[0])(param) in value
|
597
598
|
|
598
|
-
if isinstance(value, tuple) and len(value) == 2:
|
599
|
+
if isinstance(value, tuple) and len(value) == 2: # noqa: PLR2004
|
599
600
|
return value[0] <= type(value[0])(param) < value[1]
|
600
601
|
|
601
602
|
return type(value)(param) == value
|
602
603
|
|
603
604
|
|
604
|
-
def filter_runs(
|
605
|
+
def filter_runs(
|
606
|
+
runs: list[Run],
|
607
|
+
config: object | None = None,
|
608
|
+
*,
|
609
|
+
status: str | list[str] | None = None,
|
610
|
+
**kwargs,
|
611
|
+
) -> list[Run]:
|
605
612
|
"""
|
606
613
|
Filter the runs based on the provided configuration.
|
607
614
|
|
@@ -623,6 +630,7 @@ def filter_runs(runs: list[Run], config: object | None = None, **kwargs) -> list
|
|
623
630
|
config (object | None): The configuration object to filter the runs.
|
624
631
|
This can be any object that provides key-value pairs through the
|
625
632
|
`iter_params` function.
|
633
|
+
status (str | list[str] | None): The status of the runs to filter.
|
626
634
|
**kwargs: Additional key-value pairs to filter the runs.
|
627
635
|
|
628
636
|
Returns:
|
@@ -634,6 +642,15 @@ def filter_runs(runs: list[Run], config: object | None = None, **kwargs) -> list
|
|
634
642
|
if len(runs) == 0:
|
635
643
|
return []
|
636
644
|
|
645
|
+
if isinstance(status, str) and status.startswith("!"):
|
646
|
+
status = status[1:].lower()
|
647
|
+
return [run for run in runs if run.info.status.lower() != status]
|
648
|
+
|
649
|
+
if status:
|
650
|
+
status = [status] if isinstance(status, str) else status
|
651
|
+
status = [s.lower() for s in status]
|
652
|
+
return [run for run in runs if run.info.status.lower() in status]
|
653
|
+
|
637
654
|
return runs
|
638
655
|
|
639
656
|
|
File without changes
|
File without changes
|
@@ -3,12 +3,15 @@ from __future__ import annotations
|
|
3
3
|
import subprocess
|
4
4
|
import sys
|
5
5
|
from pathlib import Path
|
6
|
+
from typing import TYPE_CHECKING
|
6
7
|
|
7
8
|
import mlflow
|
8
9
|
import pytest
|
9
|
-
from omegaconf import DictConfig
|
10
10
|
|
11
|
-
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from omegaconf import DictConfig
|
13
|
+
|
14
|
+
from hydraflow.run_collection import RunCollection
|
12
15
|
|
13
16
|
|
14
17
|
@pytest.fixture
|
@@ -32,7 +35,7 @@ def test_list_runs_all(rc: RunCollection):
|
|
32
35
|
rc_ = list_runs([])
|
33
36
|
assert len(rc) == len(rc_)
|
34
37
|
|
35
|
-
for a, b in zip(rc, rc_):
|
38
|
+
for a, b in zip(rc, rc_, strict=False):
|
36
39
|
assert a.info.run_id == b.info.run_id
|
37
40
|
assert a.info.start_time == b.info.start_time
|
38
41
|
assert a.info.status == b.info.status
|
@@ -46,7 +49,7 @@ def test_list_runs_parallel(rc: RunCollection, n_jobs: int):
|
|
46
49
|
rc_ = list_runs("_info_", n_jobs=n_jobs)
|
47
50
|
assert len(rc) == len(rc_)
|
48
51
|
|
49
|
-
for a, b in zip(rc, rc_):
|
52
|
+
for a, b in zip(rc, rc_, strict=False):
|
50
53
|
assert a.info.run_id == b.info.run_id
|
51
54
|
assert a.info.start_time == b.info.start_time
|
52
55
|
assert a.info.status == b.info.status
|
@@ -61,7 +64,7 @@ def test_list_runs_parallel_active(rc: RunCollection, n_jobs: int):
|
|
61
64
|
rc_ = list_runs(n_jobs=n_jobs)
|
62
65
|
assert len(rc) == len(rc_)
|
63
66
|
|
64
|
-
for a, b in zip(rc, rc_):
|
67
|
+
for a, b in zip(rc, rc_, strict=False):
|
65
68
|
assert a.info.run_id == b.info.run_id
|
66
69
|
assert a.info.start_time == b.info.start_time
|
67
70
|
assert a.info.status == b.info.status
|
@@ -98,7 +101,6 @@ def test_app_info_config(rc: RunCollection):
|
|
98
101
|
|
99
102
|
def test_app_info_artifact_uri(rc: RunCollection):
|
100
103
|
uris = rc.info.artifact_uri
|
101
|
-
print(uris)
|
102
104
|
assert all(uri.startswith("file://") for uri in uris) # type: ignore
|
103
105
|
assert all(uri.endswith("/artifacts") for uri in uris) # type: ignore
|
104
106
|
assert all("mlruns" in uri for uri in uris) # type: ignore
|
@@ -5,10 +5,14 @@ import sys
|
|
5
5
|
from pathlib import Path
|
6
6
|
from threading import Thread
|
7
7
|
from time import sleep
|
8
|
-
from typing import Callable
|
8
|
+
from typing import TYPE_CHECKING
|
9
9
|
|
10
10
|
import pytest
|
11
|
-
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from collections.abc import Callable
|
14
|
+
|
15
|
+
from watchfiles import Change
|
12
16
|
|
13
17
|
|
14
18
|
def sleep_write(path: Path):
|
@@ -72,7 +76,9 @@ async def test_monitor_file_changes(tmp_path: Path, write_soon: Callable[[Path],
|
|
72
76
|
changes_detected.extend(changes)
|
73
77
|
|
74
78
|
write_soon(tmp_path / "test.txt")
|
75
|
-
monitor_task = asyncio.create_task(
|
79
|
+
monitor_task = asyncio.create_task(
|
80
|
+
monitor_file_changes([tmp_path], callback, stop_event),
|
81
|
+
)
|
76
82
|
|
77
83
|
await asyncio.sleep(1)
|
78
84
|
stop_event.set()
|
@@ -135,7 +135,11 @@ def test_iter_params_from_config(cfg):
|
|
135
135
|
def test_iter_params_with_empty_config():
|
136
136
|
from hydraflow.config import iter_params
|
137
137
|
|
138
|
-
empty_cfg = Config(
|
138
|
+
empty_cfg = Config(
|
139
|
+
size=Size(x=0, y=0),
|
140
|
+
db=Db(name="", port=0),
|
141
|
+
store=Store(items=[]),
|
142
|
+
)
|
139
143
|
it = iter_params(empty_cfg)
|
140
144
|
assert next(it) == ("size.x", 0)
|
141
145
|
assert next(it) == ("size.y", 0)
|
@@ -46,8 +46,7 @@ def test_output(run_id: str):
|
|
46
46
|
|
47
47
|
def read_log(run_id: str, path: str) -> str:
|
48
48
|
path = download_artifacts(run_id=run_id, artifact_path=path)
|
49
|
-
|
50
|
-
return text
|
49
|
+
return Path(path).read_text()
|
51
50
|
|
52
51
|
|
53
52
|
def test_load_config(run: Run):
|
@@ -5,8 +5,8 @@ import pytest
|
|
5
5
|
|
6
6
|
|
7
7
|
@pytest.mark.skipif(
|
8
|
-
sys.platform == "win32", reason="'cp932' codec can't encode character '\\u2807'"
|
8
|
+
sys.platform == "win32", reason="'cp932' codec can't encode character '\\u2807'",
|
9
9
|
)
|
10
10
|
def test_progress_bar():
|
11
|
-
cp = run([sys.executable, "tests/scripts/progress.py"])
|
11
|
+
cp = run([sys.executable, "tests/scripts/progress.py"], check=False)
|
12
12
|
assert cp.returncode == 0
|
@@ -90,6 +90,16 @@ def test_filter_invalid_param(run_list: list[Run]):
|
|
90
90
|
assert len(x) == 6
|
91
91
|
|
92
92
|
|
93
|
+
def test_filter_status(run_list: list[Run]):
|
94
|
+
from hydraflow.run_collection import filter_runs
|
95
|
+
|
96
|
+
assert not filter_runs(run_list, status="RUNNING")
|
97
|
+
assert filter_runs(run_list, status="finished") == run_list
|
98
|
+
assert filter_runs(run_list, status=["finished", "running"]) == run_list
|
99
|
+
assert filter_runs(run_list, status="!RUNNING") == run_list
|
100
|
+
assert not filter_runs(run_list, status="!finished")
|
101
|
+
|
102
|
+
|
93
103
|
def test_get_params(run_list: list[Run]):
|
94
104
|
from hydraflow.run_collection import get_params
|
95
105
|
|
@@ -162,8 +172,6 @@ def test_runs_filter(runs: RunCollection):
|
|
162
172
|
|
163
173
|
|
164
174
|
def test_runs_get(runs: RunCollection):
|
165
|
-
from hydraflow.run_collection import Run
|
166
|
-
|
167
175
|
run = runs.get({"p": 4})
|
168
176
|
assert isinstance(run, Run)
|
169
177
|
run = runs.get(p=2)
|
@@ -195,8 +203,6 @@ def test_runs_get_params_dict(runs: RunCollection):
|
|
195
203
|
|
196
204
|
|
197
205
|
def test_runs_find(runs: RunCollection):
|
198
|
-
from hydraflow.run_collection import Run
|
199
|
-
|
200
206
|
run = runs.find({"r": 0})
|
201
207
|
assert isinstance(run, Run)
|
202
208
|
assert run.data.params["p"] == "0"
|
@@ -216,8 +222,6 @@ def test_runs_try_find_none(runs: RunCollection):
|
|
216
222
|
|
217
223
|
|
218
224
|
def test_runs_find_last(runs: RunCollection):
|
219
|
-
from hydraflow.run_collection import Run
|
220
|
-
|
221
225
|
run = runs.find_last({"r": 0})
|
222
226
|
assert isinstance(run, Run)
|
223
227
|
assert run.data.params["p"] == "3"
|
@@ -303,7 +307,7 @@ def test_run_collection_map_run_id_kwargs(runs: RunCollection):
|
|
303
307
|
def test_run_collection_map_uri(runs: RunCollection):
|
304
308
|
results = list(runs.map_uri(lambda uri: uri))
|
305
309
|
assert len(results) == len(runs._runs)
|
306
|
-
assert all(isinstance(uri, (str, type(None))) for uri in results)
|
310
|
+
assert all(isinstance(uri, str | type(None)) for uri in results)
|
307
311
|
|
308
312
|
|
309
313
|
def test_run_collection_map_dir(runs: RunCollection):
|
@@ -8,7 +8,7 @@ import pytest
|
|
8
8
|
|
9
9
|
|
10
10
|
@pytest.mark.parametrize("dir", [".", Path])
|
11
|
-
def test_watch(dir, monkeypatch, tmp_path):
|
11
|
+
def test_watch(dir, monkeypatch, tmp_path): # noqa: A002
|
12
12
|
from hydraflow.context import watch
|
13
13
|
|
14
14
|
file = Path("tests/scripts/watch.py").absolute()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|