hydraflow 0.5.2__tar.gz → 0.5.4__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {hydraflow-0.5.2 → hydraflow-0.5.4}/.devcontainer/postCreate.sh +1 -4
- {hydraflow-0.5.2 → hydraflow-0.5.4}/PKG-INFO +2 -2
- {hydraflow-0.5.2 → hydraflow-0.5.4}/pyproject.toml +10 -34
- {hydraflow-0.5.2 → hydraflow-0.5.4}/src/hydraflow/config.py +0 -3
- {hydraflow-0.5.2 → hydraflow-0.5.4}/src/hydraflow/context.py +0 -3
- {hydraflow-0.5.2 → hydraflow-0.5.4}/src/hydraflow/mlflow.py +1 -5
- {hydraflow-0.5.2 → hydraflow-0.5.4}/src/hydraflow/param.py +4 -5
- {hydraflow-0.5.2 → hydraflow-0.5.4}/src/hydraflow/run_collection.py +24 -31
- {hydraflow-0.5.2 → hydraflow-0.5.4}/src/hydraflow/run_data.py +0 -1
- {hydraflow-0.5.2 → hydraflow-0.5.4}/src/hydraflow/utils.py +1 -6
- hydraflow-0.5.2/tests/config/test_hydra.py → hydraflow-0.5.4/tests/config/test_overrides.py +3 -1
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/conftest.py +1 -1
- hydraflow-0.5.4/tests/context/preemption.py +49 -0
- hydraflow-0.5.4/tests/context/rerun.py +44 -0
- hydraflow-0.5.2/tests/context/test_hydra.py → hydraflow-0.5.4/tests/context/test_context.py +2 -0
- hydraflow-0.5.4/tests/context/test_preemption.py +41 -0
- hydraflow-0.5.4/tests/context/test_rerun.py +31 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/param/test_param.py +10 -0
- hydraflow-0.5.2/tests/param/test_hydra.py → hydraflow-0.5.4/tests/param/test_params.py +3 -1
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/run/test_collection.py +2 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/run/test_data.py +2 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/run/test_filter.py +2 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/run/test_info.py +2 -0
- hydraflow-0.5.2/tests/run/test_hydra.py → hydraflow-0.5.4/tests/run/test_run.py +2 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/test_mlflow.py +17 -14
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/utils/test_run.py +4 -0
- hydraflow-0.5.2/tests/utils/test_hydra.py → hydraflow-0.5.4/tests/utils/test_utils.py +7 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/.devcontainer/devcontainer.json +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/.devcontainer/starship.toml +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/.gitattributes +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/.gitignore +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/LICENSE +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/README.md +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/apps/quickstart.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/mkdocs.yml +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/src/hydraflow/__init__.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/src/hydraflow/py.typed +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/src/hydraflow/run_info.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/__init__.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/config/__init__.py +0 -0
- /hydraflow-0.5.2/tests/config/config.py → /hydraflow-0.5.4/tests/config/overrides.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/config/test_config.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/config/test_params.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/context/__init__.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/context/context.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/param/__init__.py +0 -0
- /hydraflow-0.5.2/tests/param/param.py → /hydraflow-0.5.4/tests/param/params.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/run/__init__.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/run/filter.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/run/run.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/utils/__init__.py +0 -0
- {hydraflow-0.5.2 → hydraflow-0.5.4}/tests/utils/utils.py +0 -0
@@ -5,7 +5,4 @@ mkdir -p ~/.config
|
|
5
5
|
cp .devcontainer/starship.toml ~/.config
|
6
6
|
|
7
7
|
curl -LsSf https://astral.sh/uv/install.sh | sh
|
8
|
-
echo 'eval "$(uv generate-shell-completion bash)"' >> ~/.bashrc
|
9
|
-
uv tool install ruff@latest
|
10
|
-
uv python install
|
11
|
-
uv sync -U
|
8
|
+
echo 'eval "$(uv generate-shell-completion bash)"' >> ~/.bashrc
|
@@ -1,8 +1,8 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: hydraflow
|
3
|
-
Version: 0.5.2
|
3
|
+
Version: 0.5.4
|
4
4
|
Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
|
5
|
-
Project-URL: Documentation, https://github.
|
5
|
+
Project-URL: Documentation, https://daizutabi.github.io/hydraflow/
|
6
6
|
Project-URL: Source, https://github.com/daizutabi/hydraflow
|
7
7
|
Project-URL: Issues, https://github.com/daizutabi/hydraflow/issues
|
8
8
|
Author-email: daizutabi <daizutabi@gmail.com>
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "hydraflow"
|
7
|
-
version = "0.5.2"
|
7
|
+
version = "0.5.4"
|
8
8
|
description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
|
9
9
|
readme = "README.md"
|
10
10
|
license = { file = "LICENSE" }
|
@@ -22,7 +22,7 @@ requires-python = ">=3.10"
|
|
22
22
|
dependencies = ["hydra-core>=1.3", "mlflow>=2.15"]
|
23
23
|
|
24
24
|
[project.urls]
|
25
|
-
Documentation = "https://github.
|
25
|
+
Documentation = "https://daizutabi.github.io/hydraflow/"
|
26
26
|
Source = "https://github.com/daizutabi/hydraflow"
|
27
27
|
Issues = "https://github.com/daizutabi/hydraflow/issues"
|
28
28
|
|
@@ -33,6 +33,9 @@ dev-dependencies = [
|
|
33
33
|
"mkdocs-material",
|
34
34
|
"mkdocs>=1.6",
|
35
35
|
"pytest-cov",
|
36
|
+
"pytest-order",
|
37
|
+
"pytest-randomly",
|
38
|
+
"pytest-xdist",
|
36
39
|
]
|
37
40
|
|
38
41
|
[tool.hatch.build.targets.sdist]
|
@@ -43,11 +46,11 @@ packages = ["src/hydraflow"]
|
|
43
46
|
|
44
47
|
[tool.pytest.ini_options]
|
45
48
|
addopts = [
|
46
|
-
"--doctest-modules",
|
47
49
|
"--cov=hydraflow",
|
48
50
|
"--cov-report=lcov:lcov.info",
|
51
|
+
"-n8",
|
52
|
+
"--dist=loadgroup",
|
49
53
|
]
|
50
|
-
doctest_optionflags = ["NORMALIZE_WHITESPACE", "IGNORE_EXCEPTION_DETAIL"]
|
51
54
|
filterwarnings = [
|
52
55
|
"ignore:pkg_resources is deprecated:DeprecationWarning",
|
53
56
|
"ignore:Support for class-based `config` is deprecated",
|
@@ -65,35 +68,8 @@ target-version = "py310"
|
|
65
68
|
[tool.ruff.lint]
|
66
69
|
select = ["ALL"]
|
67
70
|
unfixable = ["F401"]
|
68
|
-
ignore = [
|
69
|
-
"A005",
|
70
|
-
"ANN003",
|
71
|
-
"ANN401",
|
72
|
-
"ARG002",
|
73
|
-
"B904",
|
74
|
-
"D105",
|
75
|
-
"D107",
|
76
|
-
"D203",
|
77
|
-
"D213",
|
78
|
-
"EM101",
|
79
|
-
"PGH003",
|
80
|
-
"TRY003",
|
81
|
-
]
|
82
|
-
exclude = ["tests/scripts/*.py"]
|
71
|
+
ignore = ["A005", "ANN003", "ANN401", "B904", "D", "EM101", "PGH003", "TRY003"]
|
83
72
|
|
84
73
|
[tool.ruff.lint.per-file-ignores]
|
85
|
-
"tests/*" = [
|
86
|
-
|
87
|
-
"ANN",
|
88
|
-
"ARG",
|
89
|
-
"D",
|
90
|
-
"FBT",
|
91
|
-
"PD",
|
92
|
-
"PLR",
|
93
|
-
"PT",
|
94
|
-
"S",
|
95
|
-
"SIM117",
|
96
|
-
"TID",
|
97
|
-
"SLF",
|
98
|
-
]
|
99
|
-
"apps/*.py" = ["INP", "D", "G", "T"]
|
74
|
+
"tests/*" = ["A001", "ANN", "ARG", "FBT", "PLR", "PT", "S", "SIM108", "SLF"]
|
75
|
+
"apps/*.py" = ["D", "G", "INP"]
|
@@ -22,7 +22,6 @@ def collect_params(config: object) -> dict[str, Any]:
|
|
22
22
|
|
23
23
|
Returns:
|
24
24
|
dict[str, Any]: A dictionary of collected parameters.
|
25
|
-
|
26
25
|
"""
|
27
26
|
return dict(iter_params(config))
|
28
27
|
|
@@ -41,7 +40,6 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
|
41
40
|
|
42
41
|
Yields:
|
43
42
|
Key-value pairs representing the parameters in the configuration object.
|
44
|
-
|
45
43
|
"""
|
46
44
|
if config is None:
|
47
45
|
return
|
@@ -115,7 +113,6 @@ def select_config(config: object, names: list[str]) -> dict[str, Any]:
|
|
115
113
|
|
116
114
|
Returns:
|
117
115
|
DictConfig: A new configuration object containing only the selected parameters.
|
118
|
-
|
119
116
|
"""
|
120
117
|
if not isinstance(config, DictConfig):
|
121
118
|
config = OmegaConf.structured(config)
|
@@ -48,7 +48,6 @@ def log_run(
|
|
48
48
|
# Perform operations within the MLflow run context
|
49
49
|
pass
|
50
50
|
```
|
51
|
-
|
52
51
|
"""
|
53
52
|
if config:
|
54
53
|
log_params(config, synchronous=synchronous)
|
@@ -118,7 +117,6 @@ def start_run( # noqa: PLR0913
|
|
118
117
|
- `mlflow.start_run`: The MLflow function to start a run directly.
|
119
118
|
- `log_run`: A context manager to log parameters and manage the MLflow
|
120
119
|
run context.
|
121
|
-
|
122
120
|
"""
|
123
121
|
with (
|
124
122
|
mlflow.start_run(
|
@@ -169,7 +167,6 @@ def chdir_artifact(
|
|
169
167
|
Args:
|
170
168
|
run (Run): The run to get the artifact directory from.
|
171
169
|
artifact_path (str | None): The artifact path.
|
172
|
-
|
173
170
|
"""
|
174
171
|
curdir = Path.cwd()
|
175
172
|
path = mlflow.artifacts.download_artifacts(
|
@@ -54,7 +54,6 @@ def set_experiment(
|
|
54
54
|
Returns:
|
55
55
|
Experiment: An instance of `mlflow.entities.Experiment` representing
|
56
56
|
the new active experiment.
|
57
|
-
|
58
57
|
"""
|
59
58
|
if uri is not None:
|
60
59
|
mlflow.set_tracking_uri(uri)
|
@@ -78,7 +77,6 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
|
|
78
77
|
config (object): The configuration object to log the parameters from.
|
79
78
|
synchronous (bool | None): Whether to log the parameters synchronously.
|
80
79
|
Defaults to None.
|
81
|
-
|
82
80
|
"""
|
83
81
|
for key, value in iter_params(config):
|
84
82
|
mlflow.log_param(key, value, synchronous=synchronous)
|
@@ -135,7 +133,6 @@ def search_runs( # noqa: PLR0913
|
|
135
133
|
|
136
134
|
Returns:
|
137
135
|
A `RunCollection` object containing the search results.
|
138
|
-
|
139
136
|
"""
|
140
137
|
runs = mlflow.search_runs(
|
141
138
|
experiment_ids=experiment_ids,
|
@@ -180,7 +177,6 @@ def list_runs(
|
|
180
177
|
Returns:
|
181
178
|
RunCollection: A `RunCollection` instance containing the runs for the
|
182
179
|
specified experiments.
|
183
|
-
|
184
180
|
"""
|
185
181
|
rc = _list_runs(experiment_names, n_jobs)
|
186
182
|
if status is None:
|
@@ -214,7 +210,7 @@ def _list_runs(
|
|
214
210
|
loc = experiment.artifact_location
|
215
211
|
|
216
212
|
if isinstance(loc, str):
|
217
|
-
if loc.startswith("file
|
213
|
+
if loc.startswith("file:"):
|
218
214
|
path = Path(mlflow.artifacts.download_artifacts(loc))
|
219
215
|
elif Path(loc).is_dir():
|
220
216
|
path = Path(loc)
|
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
|
|
18
18
|
from mlflow.entities import Run
|
19
19
|
|
20
20
|
|
21
|
-
def match(param: str, value: Any) -> bool:
|
21
|
+
def match(param: str, value: Any) -> bool: # noqa: PLR0911
|
22
22
|
"""Check if the string matches the specified value.
|
23
23
|
|
24
24
|
Args:
|
@@ -28,8 +28,10 @@ def match(param: str, value: Any) -> bool:
|
|
28
28
|
Returns:
|
29
29
|
True if the parameter matches the specified value,
|
30
30
|
False otherwise.
|
31
|
-
|
32
31
|
"""
|
32
|
+
if callable(value):
|
33
|
+
return value(param)
|
34
|
+
|
33
35
|
if any(value is x for x in [None, True, False]):
|
34
36
|
return param == str(value)
|
35
37
|
|
@@ -92,7 +94,6 @@ def to_value(param: str | None, type_: type) -> Any:
|
|
92
94
|
|
93
95
|
Returns:
|
94
96
|
The converted value.
|
95
|
-
|
96
97
|
"""
|
97
98
|
if param is None or param == "None":
|
98
99
|
return None
|
@@ -128,7 +129,6 @@ def get_params(run: Run, *names: str | list[str]) -> tuple[str | None, ...]:
|
|
128
129
|
Returns:
|
129
130
|
tuple[str | None, ...]: A tuple containing the values of the specified
|
130
131
|
parameters in the order they were provided.
|
131
|
-
|
132
132
|
"""
|
133
133
|
names_ = []
|
134
134
|
for name in names:
|
@@ -155,7 +155,6 @@ def get_values(run: Run, names: list[str], types: list[type]) -> tuple[Any, ...]
|
|
155
155
|
Returns:
|
156
156
|
tuple[Any, ...]: A tuple containing the values of the specified
|
157
157
|
parameters in the order they were provided.
|
158
|
-
|
159
158
|
"""
|
160
159
|
params = get_params(run, names)
|
161
160
|
it = zip(params, types, strict=True)
|
@@ -106,7 +106,6 @@ class RunCollection:
|
|
106
106
|
|
107
107
|
Returns:
|
108
108
|
A new `RunCollection` instance with the runs from both collections.
|
109
|
-
|
110
109
|
"""
|
111
110
|
return self.__class__(self._runs + other._runs)
|
112
111
|
|
@@ -119,7 +118,6 @@ class RunCollection:
|
|
119
118
|
Returns:
|
120
119
|
A new `RunCollection` instance with the runs that are in this collection
|
121
120
|
but not in the other.
|
122
|
-
|
123
121
|
"""
|
124
122
|
runs = [run for run in self._runs if run not in other._runs] # noqa: SLF001
|
125
123
|
return self.__class__(runs)
|
@@ -152,7 +150,6 @@ class RunCollection:
|
|
152
150
|
Returns:
|
153
151
|
A new `RunCollection` instance containing the first n runs if n is
|
154
152
|
positive, or the last n runs if n is negative.
|
155
|
-
|
156
153
|
"""
|
157
154
|
if n < 0:
|
158
155
|
return self.__class__(self._runs[n:])
|
@@ -167,7 +164,6 @@ class RunCollection:
|
|
167
164
|
|
168
165
|
Raises:
|
169
166
|
ValueError: If the collection does not contain exactly one run.
|
170
|
-
|
171
167
|
"""
|
172
168
|
if len(self._runs) != 1:
|
173
169
|
raise ValueError("The collection does not contain exactly one run.")
|
@@ -180,7 +176,6 @@ class RunCollection:
|
|
180
176
|
Returns:
|
181
177
|
The only `Run` instance in the collection, or None if the collection
|
182
178
|
does not contain exactly one run.
|
183
|
-
|
184
179
|
"""
|
185
180
|
return self._runs[0] if len(self._runs) == 1 else None
|
186
181
|
|
@@ -192,7 +187,6 @@ class RunCollection:
|
|
192
187
|
|
193
188
|
Raises:
|
194
189
|
ValueError: If the collection is empty.
|
195
|
-
|
196
190
|
"""
|
197
191
|
if not self._runs:
|
198
192
|
raise ValueError("The collection is empty.")
|
@@ -205,7 +199,6 @@ class RunCollection:
|
|
205
199
|
Returns:
|
206
200
|
The first `Run` instance in the collection, or None if the collection
|
207
201
|
is empty.
|
208
|
-
|
209
202
|
"""
|
210
203
|
return self._runs[0] if self._runs else None
|
211
204
|
|
@@ -217,7 +210,6 @@ class RunCollection:
|
|
217
210
|
|
218
211
|
Raises:
|
219
212
|
ValueError: If the collection is empty.
|
220
|
-
|
221
213
|
"""
|
222
214
|
if not self._runs:
|
223
215
|
raise ValueError("The collection is empty.")
|
@@ -230,11 +222,18 @@ class RunCollection:
|
|
230
222
|
Returns:
|
231
223
|
The last `Run` instance in the collection, or None if the collection
|
232
224
|
is empty.
|
233
|
-
|
234
225
|
"""
|
235
226
|
return self._runs[-1] if self._runs else None
|
236
227
|
|
237
|
-
def filter(
|
228
|
+
def filter(
|
229
|
+
self,
|
230
|
+
config: object | None = None,
|
231
|
+
*,
|
232
|
+
override: bool = False,
|
233
|
+
select: list[str] | None = None,
|
234
|
+
status: str | list[str] | int | list[int] | None = None,
|
235
|
+
**kwargs,
|
236
|
+
) -> RunCollection:
|
238
237
|
"""Filter the `Run` instances based on the provided configuration.
|
239
238
|
|
240
239
|
This method filters the runs in the collection according to the
|
@@ -254,13 +253,26 @@ class RunCollection:
|
|
254
253
|
config (object | None): The configuration object to filter the runs.
|
255
254
|
This can be any object that provides key-value pairs through
|
256
255
|
the `iter_params` function.
|
256
|
+
override (bool): If True, override the configuration object with the
|
257
|
+
provided key-value pairs.
|
258
|
+
select (list[str] | None): The list of parameters to select.
|
259
|
+
status (str | list[str] | int | list[int] | None): The status of the
|
260
|
+
runs to filter.
|
257
261
|
**kwargs: Additional key-value pairs to filter the runs.
|
258
262
|
|
259
263
|
Returns:
|
260
264
|
A new `RunCollection` object containing the filtered runs.
|
261
|
-
|
262
265
|
"""
|
263
|
-
return RunCollection(
|
266
|
+
return RunCollection(
|
267
|
+
filter_runs(
|
268
|
+
self._runs,
|
269
|
+
config,
|
270
|
+
override=override,
|
271
|
+
select=select,
|
272
|
+
status=status,
|
273
|
+
**kwargs,
|
274
|
+
),
|
275
|
+
)
|
264
276
|
|
265
277
|
def find(self, config: object | None = None, **kwargs) -> Run:
|
266
278
|
"""Find the first `Run` instance based on the provided configuration.
|
@@ -282,7 +294,6 @@ class RunCollection:
|
|
282
294
|
|
283
295
|
See Also:
|
284
296
|
`filter`: Perform the actual filtering logic.
|
285
|
-
|
286
297
|
"""
|
287
298
|
try:
|
288
299
|
return self.filter(config, **kwargs).first()
|
@@ -307,7 +318,6 @@ class RunCollection:
|
|
307
318
|
|
308
319
|
See Also:
|
309
320
|
`filter`: Perform the actual filtering logic.
|
310
|
-
|
311
321
|
"""
|
312
322
|
return self.filter(config, **kwargs).try_first()
|
313
323
|
|
@@ -331,7 +341,6 @@ class RunCollection:
|
|
331
341
|
|
332
342
|
See Also:
|
333
343
|
`filter`: Perform the actual filtering logic.
|
334
|
-
|
335
344
|
"""
|
336
345
|
try:
|
337
346
|
return self.filter(config, **kwargs).last()
|
@@ -356,7 +365,6 @@ class RunCollection:
|
|
356
365
|
|
357
366
|
See Also:
|
358
367
|
`filter`: Perform the actual filtering logic.
|
359
|
-
|
360
368
|
"""
|
361
369
|
return self.filter(config, **kwargs).try_last()
|
362
370
|
|
@@ -381,7 +389,6 @@ class RunCollection:
|
|
381
389
|
|
382
390
|
See Also:
|
383
391
|
`filter`: Perform the actual filtering logic.
|
384
|
-
|
385
392
|
"""
|
386
393
|
try:
|
387
394
|
return self.filter(config, **kwargs).one()
|
@@ -410,7 +417,6 @@ class RunCollection:
|
|
410
417
|
|
411
418
|
See Also:
|
412
419
|
`filter`: Perform the actual filtering logic.
|
413
|
-
|
414
420
|
"""
|
415
421
|
return self.filter(config, **kwargs).try_one()
|
416
422
|
|
@@ -423,7 +429,6 @@ class RunCollection:
|
|
423
429
|
|
424
430
|
Returns:
|
425
431
|
A list of unique parameter names.
|
426
|
-
|
427
432
|
"""
|
428
433
|
param_names = set()
|
429
434
|
|
@@ -448,7 +453,6 @@ class RunCollection:
|
|
448
453
|
Returns:
|
449
454
|
A dictionary where the keys are parameter names and the values are
|
450
455
|
lists of parameter values.
|
451
|
-
|
452
456
|
"""
|
453
457
|
params = {}
|
454
458
|
|
@@ -480,7 +484,6 @@ class RunCollection:
|
|
480
484
|
|
481
485
|
Yields:
|
482
486
|
Results obtained by applying the function to each run in the collection.
|
483
|
-
|
484
487
|
"""
|
485
488
|
return (func(run, *args, **kwargs) for run in self)
|
486
489
|
|
@@ -501,7 +504,6 @@ class RunCollection:
|
|
501
504
|
Yields:
|
502
505
|
Results obtained by applying the function to each run id in the
|
503
506
|
collection.
|
504
|
-
|
505
507
|
"""
|
506
508
|
return (func(run_id, *args, **kwargs) for run_id in self.info.run_id)
|
507
509
|
|
@@ -522,7 +524,6 @@ class RunCollection:
|
|
522
524
|
Yields:
|
523
525
|
Results obtained by applying the function to each run configuration
|
524
526
|
in the collection.
|
525
|
-
|
526
527
|
"""
|
527
528
|
return (func(load_config(run), *args, **kwargs) for run in self)
|
528
529
|
|
@@ -547,7 +548,6 @@ class RunCollection:
|
|
547
548
|
Yields:
|
548
549
|
Results obtained by applying the function to each artifact URI in the
|
549
550
|
collection.
|
550
|
-
|
551
551
|
"""
|
552
552
|
return (func(uri, *args, **kwargs) for uri in self.info.artifact_uri)
|
553
553
|
|
@@ -571,7 +571,6 @@ class RunCollection:
|
|
571
571
|
Yields:
|
572
572
|
Results obtained by applying the function to each artifact directory
|
573
573
|
in the collection.
|
574
|
-
|
575
574
|
"""
|
576
575
|
return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir) # noqa: A001
|
577
576
|
|
@@ -595,7 +594,6 @@ class RunCollection:
|
|
595
594
|
dictionary where the keys are tuples of parameter values and the
|
596
595
|
values are `RunCollection` objects containing the runs that match
|
597
596
|
those parameter values.
|
598
|
-
|
599
597
|
"""
|
600
598
|
grouped_runs: dict[str | None | tuple[str | None, ...], list[Run]] = {}
|
601
599
|
is_list = isinstance(names, list)
|
@@ -622,7 +620,6 @@ class RunCollection:
|
|
622
620
|
key (Callable[[Run], Any] | None): A function that takes a run and returns
|
623
621
|
a value to sort by.
|
624
622
|
reverse (bool): If True, sort in descending order.
|
625
|
-
|
626
623
|
"""
|
627
624
|
self._runs.sort(key=key or (lambda x: x.info.start_time), reverse=reverse)
|
628
625
|
|
@@ -636,7 +633,6 @@ class RunCollection:
|
|
636
633
|
|
637
634
|
Returns:
|
638
635
|
A list of values for the specified parameters.
|
639
|
-
|
640
636
|
"""
|
641
637
|
is_list = isinstance(names, list)
|
642
638
|
|
@@ -668,7 +664,6 @@ class RunCollection:
|
|
668
664
|
This can be a single parameter name or multiple names provided
|
669
665
|
as separate arguments or as a list.
|
670
666
|
reverse (bool): If True, sort in descending order.
|
671
|
-
|
672
667
|
"""
|
673
668
|
values = self.values(names)
|
674
669
|
index = sorted(range(len(self)), key=lambda i: values[i], reverse=reverse)
|
@@ -726,7 +721,6 @@ def filter_runs(
|
|
726
721
|
|
727
722
|
Returns:
|
728
723
|
A list of runs that match the specified configuration and key-value pairs.
|
729
|
-
|
730
724
|
"""
|
731
725
|
if override:
|
732
726
|
config = select_overrides(config)
|
@@ -757,7 +751,6 @@ def filter_runs_by_status(
|
|
757
751
|
|
758
752
|
Returns:
|
759
753
|
A list of runs that match the specified status.
|
760
|
-
|
761
754
|
"""
|
762
755
|
if isinstance(status, str):
|
763
756
|
if status.startswith("!"):
|
@@ -26,14 +26,13 @@ def get_artifact_dir(run: Run | None = None) -> Path:
|
|
26
26
|
|
27
27
|
Returns:
|
28
28
|
The local path to the directory where the artifacts are downloaded.
|
29
|
-
|
30
29
|
"""
|
31
30
|
uri = mlflow.get_artifact_uri() if run is None else run.info.artifact_uri
|
32
31
|
|
33
32
|
if not isinstance(uri, str):
|
34
33
|
raise NotImplementedError
|
35
34
|
|
36
|
-
if uri.startswith("file
|
35
|
+
if uri.startswith("file:"):
|
37
36
|
return Path(mlflow.artifacts.download_artifacts(uri))
|
38
37
|
|
39
38
|
if Path(uri).is_dir():
|
@@ -53,7 +52,6 @@ def get_artifact_path(run: Run | None, path: str) -> Path:
|
|
53
52
|
|
54
53
|
Returns:
|
55
54
|
The local path to the artifact.
|
56
|
-
|
57
55
|
"""
|
58
56
|
return get_artifact_dir(run) / path
|
59
57
|
|
@@ -76,7 +74,6 @@ def get_hydra_output_dir(run: Run | None = None) -> Path:
|
|
76
74
|
Raises:
|
77
75
|
FileNotFoundError: If the Hydra configuration file is not found
|
78
76
|
in the artifacts.
|
79
|
-
|
80
77
|
"""
|
81
78
|
if run is None:
|
82
79
|
hc = HydraConfig.get()
|
@@ -105,7 +102,6 @@ def load_config(run: Run) -> DictConfig:
|
|
105
102
|
Returns:
|
106
103
|
The loaded configuration as a DictConfig object. Returns an empty
|
107
104
|
DictConfig if the configuration file is not found.
|
108
|
-
|
109
105
|
"""
|
110
106
|
path = get_artifact_dir(run) / ".hydra/config.yaml"
|
111
107
|
return OmegaConf.load(path) # type: ignore
|
@@ -130,7 +126,6 @@ def load_overrides(run: Run) -> list[str]:
|
|
130
126
|
Returns:
|
131
127
|
The loaded overrides as a list of strings. Returns an empty list
|
132
128
|
if the overrides file is not found.
|
133
|
-
|
134
129
|
"""
|
135
130
|
path = get_artifact_dir(run) / ".hydra/overrides.yaml"
|
136
131
|
return [str(x) for x in OmegaConf.load(path)]
|
@@ -6,11 +6,13 @@ from mlflow.entities import Run
|
|
6
6
|
|
7
7
|
from hydraflow.run_collection import RunCollection
|
8
8
|
|
9
|
+
pytestmark = pytest.mark.xdist_group(name="group1")
|
10
|
+
|
9
11
|
|
10
12
|
@pytest.fixture(scope="module")
|
11
13
|
def rc(collect):
|
12
14
|
args = ["-m", "name=a,b", "height=3"]
|
13
|
-
return collect("config/config.py", args)
|
15
|
+
return collect("config/overrides.py", args)
|
14
16
|
|
15
17
|
|
16
18
|
@pytest.fixture(scope="module")
|
@@ -0,0 +1,49 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import TYPE_CHECKING
|
5
|
+
|
6
|
+
import hydra
|
7
|
+
from hydra.core.config_store import ConfigStore
|
8
|
+
|
9
|
+
import hydraflow
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from pathlib import Path
|
13
|
+
|
14
|
+
|
15
|
+
@dataclass
|
16
|
+
class Config:
|
17
|
+
count: int = 0
|
18
|
+
|
19
|
+
|
20
|
+
cs = ConfigStore.instance()
|
21
|
+
cs.store(name="config", node=Config)
|
22
|
+
|
23
|
+
|
24
|
+
@hydra.main(version_base=None, config_name="config")
|
25
|
+
def app(cfg: Config):
|
26
|
+
hydraflow.set_experiment()
|
27
|
+
|
28
|
+
rc = hydraflow.list_runs()
|
29
|
+
|
30
|
+
if rc.filter(cfg, status="finished", override=True):
|
31
|
+
return
|
32
|
+
|
33
|
+
if run := rc.try_find(cfg, override=True):
|
34
|
+
run_id = run.info.run_id
|
35
|
+
else:
|
36
|
+
run_id = None
|
37
|
+
|
38
|
+
with hydraflow.start_run(cfg, run_id=run_id) as run:
|
39
|
+
log(hydraflow.get_artifact_dir(run))
|
40
|
+
|
41
|
+
|
42
|
+
def log(path: Path):
|
43
|
+
file = path / "a.txt"
|
44
|
+
text = file.read_text() if file.exists() else ""
|
45
|
+
file.write_text(text + "a")
|
46
|
+
|
47
|
+
|
48
|
+
if __name__ == "__main__":
|
49
|
+
app()
|
@@ -0,0 +1,44 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import TYPE_CHECKING
|
5
|
+
|
6
|
+
import hydra
|
7
|
+
from hydra.core.config_store import ConfigStore
|
8
|
+
|
9
|
+
import hydraflow
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from pathlib import Path
|
13
|
+
|
14
|
+
|
15
|
+
@dataclass
|
16
|
+
class Config:
|
17
|
+
count: int = 0
|
18
|
+
|
19
|
+
|
20
|
+
cs = ConfigStore.instance()
|
21
|
+
cs.store(name="config", node=Config)
|
22
|
+
|
23
|
+
|
24
|
+
@hydra.main(version_base=None, config_name="config")
|
25
|
+
def app(cfg: Config):
|
26
|
+
hydraflow.set_experiment()
|
27
|
+
|
28
|
+
if run := hydraflow.list_runs().try_find(cfg, override=True):
|
29
|
+
run_id = run.info.run_id
|
30
|
+
else:
|
31
|
+
run_id = None
|
32
|
+
|
33
|
+
with hydraflow.start_run(cfg, run_id=run_id) as run:
|
34
|
+
log(hydraflow.get_artifact_dir(run))
|
35
|
+
|
36
|
+
|
37
|
+
def log(path: Path):
|
38
|
+
file = path / "a.txt"
|
39
|
+
text = file.read_text() if file.exists() else ""
|
40
|
+
file.write_text(text + "a")
|
41
|
+
|
42
|
+
|
43
|
+
if __name__ == "__main__":
|
44
|
+
app()
|
@@ -4,6 +4,8 @@ from mlflow.entities import Run
|
|
4
4
|
from hydraflow.run_collection import RunCollection
|
5
5
|
from hydraflow.utils import get_artifact_path, get_hydra_output_dir
|
6
6
|
|
7
|
+
pytestmark = pytest.mark.xdist_group(name="group2")
|
8
|
+
|
7
9
|
|
8
10
|
@pytest.fixture(scope="module")
|
9
11
|
def rc(collect):
|
@@ -0,0 +1,41 @@
|
|
1
|
+
import pytest
|
2
|
+
from mlflow.entities import Run, RunStatus
|
3
|
+
from mlflow.tracking import MlflowClient
|
4
|
+
|
5
|
+
from hydraflow.run_collection import RunCollection
|
6
|
+
|
7
|
+
pytestmark = pytest.mark.xdist_group(name="group4")
|
8
|
+
|
9
|
+
|
10
|
+
@pytest.fixture(scope="module")
|
11
|
+
def rc(collect):
|
12
|
+
client = MlflowClient()
|
13
|
+
running = RunStatus.to_string(RunStatus.RUNNING)
|
14
|
+
|
15
|
+
filename = "context/preemption.py"
|
16
|
+
args = ["-m", "count=1,2,3"]
|
17
|
+
|
18
|
+
rc = collect(filename, args)
|
19
|
+
client.set_terminated(rc.get(count=2).info.run_id, status=running)
|
20
|
+
client.set_terminated(rc.get(count=3).info.run_id, status=running)
|
21
|
+
rc = collect(filename, args)
|
22
|
+
client.set_terminated(rc.get(count=3).info.run_id, status=running)
|
23
|
+
return collect(filename, args)
|
24
|
+
|
25
|
+
|
26
|
+
def test_rc_len(rc: RunCollection):
|
27
|
+
assert len(rc) == 3
|
28
|
+
|
29
|
+
|
30
|
+
@pytest.fixture(scope="module", params=[1, 2, 3])
|
31
|
+
def run(rc: RunCollection, request: pytest.FixtureRequest):
|
32
|
+
return rc.get(count=request.param)
|
33
|
+
|
34
|
+
|
35
|
+
def test_run_count(run: Run):
|
36
|
+
from hydraflow.utils import get_artifact_path
|
37
|
+
|
38
|
+
count = int(run.data.params["count"])
|
39
|
+
path = get_artifact_path(run, "a.txt")
|
40
|
+
text = path.read_text()
|
41
|
+
assert len(text) == count
|
@@ -0,0 +1,31 @@
|
|
1
|
+
import pytest
|
2
|
+
from mlflow.entities import Run
|
3
|
+
|
4
|
+
from hydraflow.run_collection import RunCollection
|
5
|
+
|
6
|
+
pytestmark = pytest.mark.xdist_group(name="group3")
|
7
|
+
|
8
|
+
|
9
|
+
@pytest.fixture(scope="module")
|
10
|
+
def rc(collect):
|
11
|
+
collect("context/rerun.py", ["-m", "count=1,2,3"])
|
12
|
+
collect("context/rerun.py", ["-m", "count=2,3"])
|
13
|
+
return collect("context/rerun.py", ["-m", "count=3"])
|
14
|
+
|
15
|
+
|
16
|
+
def test_rc_len(rc: RunCollection):
|
17
|
+
assert len(rc) == 3
|
18
|
+
|
19
|
+
|
20
|
+
@pytest.fixture(scope="module", params=[1, 2, 3])
|
21
|
+
def run(rc: RunCollection, request: pytest.FixtureRequest):
|
22
|
+
return rc.get(count=request.param)
|
23
|
+
|
24
|
+
|
25
|
+
def test_run_count(run: Run):
|
26
|
+
from hydraflow.utils import get_artifact_path
|
27
|
+
|
28
|
+
count = int(run.data.params["count"])
|
29
|
+
path = get_artifact_path(run, "a.txt")
|
30
|
+
text = path.read_text()
|
31
|
+
assert len(text) == count
|
@@ -5,6 +5,8 @@ import pytest
|
|
5
5
|
|
6
6
|
from hydraflow.param import match
|
7
7
|
|
8
|
+
pytestmark = pytest.mark.xdist_group(name="group2")
|
9
|
+
|
8
10
|
|
9
11
|
@pytest.fixture
|
10
12
|
def param(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
|
@@ -51,6 +53,14 @@ def test_param(param, x, y):
|
|
51
53
|
assert match(p, x)
|
52
54
|
|
53
55
|
|
56
|
+
@pytest.mark.parametrize(
|
57
|
+
("param", "value"),
|
58
|
+
[("1.0", lambda x: float(x) > 0), ("-1.0", lambda x: float(x) < 0)],
|
59
|
+
)
|
60
|
+
def test_match_callable(param, value):
|
61
|
+
assert match(param, value)
|
62
|
+
|
63
|
+
|
54
64
|
@pytest.mark.parametrize(
|
55
65
|
("param", "value"),
|
56
66
|
[("1.0", 1.0), ("1.0", 1), ("0.0", 0), ("0.0", 0.0)],
|
@@ -3,11 +3,13 @@ from mlflow.entities import Run
|
|
3
3
|
|
4
4
|
from hydraflow.run_collection import RunCollection
|
5
5
|
|
6
|
+
pytestmark = pytest.mark.xdist_group(name="group2")
|
7
|
+
|
6
8
|
|
7
9
|
@pytest.fixture(scope="module")
|
8
10
|
def rc(collect):
|
9
11
|
args = ["host=a"]
|
10
|
-
return collect("param/param.py", args)
|
12
|
+
return collect("param/params.py", args)
|
11
13
|
|
12
14
|
|
13
15
|
@pytest.fixture(scope="module")
|
@@ -7,6 +7,8 @@ from mlflow.entities import Experiment, Run, RunStatus
|
|
7
7
|
from hydraflow.mlflow import list_runs
|
8
8
|
from hydraflow.run_collection import RunCollection, filter_runs
|
9
9
|
|
10
|
+
pytestmark = pytest.mark.xdist_group(name="group3")
|
11
|
+
|
10
12
|
|
11
13
|
@pytest.fixture(scope="module")
|
12
14
|
def experiment(experiment_name: str):
|
@@ -1,21 +1,17 @@
|
|
1
|
-
import sys
|
2
1
|
from pathlib import Path
|
3
2
|
|
4
3
|
import mlflow
|
5
4
|
import pytest
|
6
5
|
from mlflow.entities import Experiment, Run, RunStatus
|
7
6
|
|
8
|
-
pytestmark = pytest.mark.skipif(
|
9
|
-
sys.platform == "win32",
|
10
|
-
reason="Windows is not supported",
|
11
|
-
)
|
7
|
+
pytestmark = pytest.mark.xdist_group(name="group0")
|
12
8
|
|
13
9
|
|
14
10
|
@pytest.fixture(scope="module")
|
15
11
|
def experiment(experiment_name: str):
|
16
12
|
from hydraflow.mlflow import log_params, set_experiment
|
17
13
|
|
18
|
-
experiment = set_experiment(uri="
|
14
|
+
experiment = set_experiment(uri="test_mlflow", name="e")
|
19
15
|
|
20
16
|
with mlflow.start_run():
|
21
17
|
log_params({"name": experiment_name})
|
@@ -26,18 +22,21 @@ def experiment(experiment_name: str):
|
|
26
22
|
mlflow.start_run()
|
27
23
|
mlflow.end_run(status=RunStatus.to_string(RunStatus.FAILED))
|
28
24
|
|
29
|
-
|
30
|
-
|
31
|
-
mlflow.set_tracking_uri("")
|
25
|
+
return experiment
|
32
26
|
|
33
27
|
|
34
28
|
def test_set_experiment_uri(experiment: Experiment):
|
35
|
-
assert mlflow.get_tracking_uri() == "
|
29
|
+
assert mlflow.get_tracking_uri() == "test_mlflow"
|
36
30
|
|
37
31
|
|
38
32
|
def test_set_experiment_location(experiment: Experiment):
|
39
|
-
loc =
|
40
|
-
assert
|
33
|
+
loc = experiment.artifact_location
|
34
|
+
assert isinstance(loc, str)
|
35
|
+
if loc.startswith("file:"): # for windows
|
36
|
+
loc = loc[loc.index("C:") :]
|
37
|
+
|
38
|
+
path = Path.cwd() / "test_mlflow" / experiment.experiment_id
|
39
|
+
assert path == Path(loc)
|
41
40
|
|
42
41
|
|
43
42
|
def test_set_experiment_name(experiment: Experiment):
|
@@ -68,8 +67,12 @@ def test_log_params(run: Run, experiment_name):
|
|
68
67
|
def test_get_artifact_dir_from_utils(run: Run, experiment: Experiment):
|
69
68
|
from hydraflow.utils import get_artifact_dir
|
70
69
|
|
71
|
-
|
72
|
-
assert
|
70
|
+
loc = experiment.artifact_location
|
71
|
+
assert isinstance(loc, str)
|
72
|
+
if loc.startswith("file:"): # for windows
|
73
|
+
loc = loc[loc.index("C:") :]
|
74
|
+
|
75
|
+
assert get_artifact_dir(run) == Path(loc) / run.info.run_id / "artifacts"
|
73
76
|
|
74
77
|
|
75
78
|
@pytest.mark.parametrize(
|
@@ -4,6 +4,8 @@ from mlflow.entities import Experiment, Run
|
|
4
4
|
|
5
5
|
from hydraflow.run_collection import RunCollection
|
6
6
|
|
7
|
+
pytestmark = pytest.mark.xdist_group(name="group5")
|
8
|
+
|
7
9
|
|
8
10
|
@pytest.fixture(scope="module")
|
9
11
|
def experiment(experiment_name: str):
|
@@ -28,6 +30,7 @@ def run(rc: RunCollection):
|
|
28
30
|
return rc.first()
|
29
31
|
|
30
32
|
|
33
|
+
@pytest.mark.order(0)
|
31
34
|
def test_hydra_output_dir(run: Run):
|
32
35
|
from hydraflow.utils import get_hydra_output_dir
|
33
36
|
|
@@ -35,6 +38,7 @@ def test_hydra_output_dir(run: Run):
|
|
35
38
|
get_hydra_output_dir(run)
|
36
39
|
|
37
40
|
|
41
|
+
@pytest.mark.order(1)
|
38
42
|
def test_remove_run(rc: RunCollection):
|
39
43
|
from hydraflow.utils import get_artifact_dir, remove_run
|
40
44
|
|
@@ -8,13 +8,20 @@ from hydraflow.run_collection import RunCollection
|
|
8
8
|
if TYPE_CHECKING:
|
9
9
|
from .utils import Config
|
10
10
|
|
11
|
+
pytestmark = pytest.mark.xdist_group(name="group6")
|
12
|
+
|
11
13
|
|
12
14
|
@pytest.fixture(scope="module")
|
13
15
|
def rc(collect):
|
14
16
|
args = ["-m", "name=a,b", "age=10"]
|
17
|
+
|
15
18
|
return collect("utils/utils.py", args)
|
16
19
|
|
17
20
|
|
21
|
+
def test_rc_len(rc: RunCollection):
|
22
|
+
assert len(rc) == 2
|
23
|
+
|
24
|
+
|
18
25
|
@pytest.fixture(scope="module")
|
19
26
|
def run(rc: RunCollection):
|
20
27
|
return rc.first()
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|