hydraflow 0.2.18__tar.gz → 0.3.0__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {hydraflow-0.2.18 → hydraflow-0.3.0}/PKG-INFO +2 -1
- {hydraflow-0.2.18 → hydraflow-0.3.0}/pyproject.toml +4 -2
- {hydraflow-0.2.18 → hydraflow-0.3.0}/src/hydraflow/__init__.py +4 -1
- {hydraflow-0.2.18 → hydraflow-0.3.0}/src/hydraflow/config.py +14 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/src/hydraflow/context.py +1 -1
- {hydraflow-0.2.18 → hydraflow-0.3.0}/src/hydraflow/run_collection.py +28 -5
- hydraflow-0.3.0/src/hydraflow/run_data.py +56 -0
- hydraflow-0.2.18/src/hydraflow/info.py → hydraflow-0.3.0/src/hydraflow/run_info.py +1 -36
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/test_app.py +20 -12
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/test_config.py +8 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/test_log_run.py +1 -1
- hydraflow-0.3.0/tests/test_run_data.py +43 -0
- hydraflow-0.2.18/tests/test_info.py → hydraflow-0.3.0/tests/test_run_info.py +1 -17
- {hydraflow-0.2.18 → hydraflow-0.3.0}/.devcontainer/devcontainer.json +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/.devcontainer/postCreate.sh +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/.devcontainer/starship.toml +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/.gitattributes +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/.gitignore +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/LICENSE +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/README.md +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/mkdocs.yml +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/src/hydraflow/asyncio.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/src/hydraflow/mlflow.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/src/hydraflow/param.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/src/hydraflow/progress.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/src/hydraflow/py.typed +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/__init__.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/scripts/__init__.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/scripts/app.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/scripts/progress.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/scripts/watch.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/test_asyncio.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/test_context.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/test_mlflow.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/test_param.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/test_progress.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/test_run_collection.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/test_version.py +0 -0
- {hydraflow-0.2.18 → hydraflow-0.3.0}/tests/test_watch.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: hydraflow
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.3.0
|
4
4
|
Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
|
5
5
|
Project-URL: Documentation, https://github.com/daizutabi/hydraflow
|
6
6
|
Project-URL: Source, https://github.com/daizutabi/hydraflow
|
@@ -17,6 +17,7 @@ Requires-Python: >=3.10
|
|
17
17
|
Requires-Dist: hydra-core>=1.3
|
18
18
|
Requires-Dist: joblib
|
19
19
|
Requires-Dist: mlflow>=2.15
|
20
|
+
Requires-Dist: polars
|
20
21
|
Requires-Dist: rich
|
21
22
|
Requires-Dist: watchdog
|
22
23
|
Requires-Dist: watchfiles
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "hydraflow"
|
7
|
-
version = "0.
|
7
|
+
version = "0.3.0"
|
8
8
|
description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
|
9
9
|
readme = "README.md"
|
10
10
|
license = "MIT"
|
@@ -21,6 +21,7 @@ dependencies = [
|
|
21
21
|
"hydra-core>=1.3",
|
22
22
|
"joblib",
|
23
23
|
"mlflow>=2.15",
|
24
|
+
"polars",
|
24
25
|
"rich",
|
25
26
|
"watchdog",
|
26
27
|
"watchfiles",
|
@@ -80,7 +81,7 @@ ignore = [
|
|
80
81
|
"PGH003",
|
81
82
|
"TRY003",
|
82
83
|
]
|
83
|
-
exclude = ["tests/scripts/*.py"
|
84
|
+
exclude = ["tests/scripts/*.py"]
|
84
85
|
|
85
86
|
[tool.ruff.lint.per-file-ignores]
|
86
87
|
"tests/*" = [
|
@@ -89,6 +90,7 @@ exclude = ["tests/scripts/*.py", "src/hydraflow/__init__.py"]
|
|
89
90
|
"ARG",
|
90
91
|
"D",
|
91
92
|
"FBT",
|
93
|
+
"PD",
|
92
94
|
"PLR",
|
93
95
|
"PT",
|
94
96
|
"S",
|
@@ -1,5 +1,6 @@
|
|
1
|
+
"""Provide a collection of MLflow runs."""
|
2
|
+
|
1
3
|
from .context import chdir_artifact, log_run, start_run, watch
|
2
|
-
from .info import get_artifact_dir, get_hydra_output_dir, load_config
|
3
4
|
from .mlflow import (
|
4
5
|
list_runs,
|
5
6
|
search_runs,
|
@@ -7,6 +8,8 @@ from .mlflow import (
|
|
7
8
|
)
|
8
9
|
from .progress import multi_tasks_progress, parallel_progress
|
9
10
|
from .run_collection import RunCollection
|
11
|
+
from .run_data import load_config
|
12
|
+
from .run_info import get_artifact_dir, get_hydra_output_dir
|
10
13
|
|
11
14
|
__all__ = [
|
12
15
|
"RunCollection",
|
@@ -11,6 +11,20 @@ if TYPE_CHECKING:
|
|
11
11
|
from typing import Any
|
12
12
|
|
13
13
|
|
14
|
+
def collect_params(config: object) -> dict[str, Any]:
|
15
|
+
"""Iterate over parameters and collect them into a dictionary.
|
16
|
+
|
17
|
+
Args:
|
18
|
+
config (object): The configuration object to iterate over.
|
19
|
+
prefix (str): The prefix to prepend to the parameter keys.
|
20
|
+
|
21
|
+
Returns:
|
22
|
+
dict[str, Any]: A dictionary of collected parameters.
|
23
|
+
|
24
|
+
"""
|
25
|
+
return dict(iter_params(config))
|
26
|
+
|
27
|
+
|
14
28
|
def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
|
15
29
|
"""Recursively iterate over the parameters in the given configuration object.
|
16
30
|
|
@@ -14,8 +14,8 @@ from hydra.core.hydra_config import HydraConfig
|
|
14
14
|
from watchdog.events import FileModifiedEvent, PatternMatchingEventHandler
|
15
15
|
from watchdog.observers import Observer
|
16
16
|
|
17
|
-
from hydraflow.info import get_artifact_dir
|
18
17
|
from hydraflow.mlflow import log_params
|
18
|
+
from hydraflow.run_info import get_artifact_dir
|
19
19
|
|
20
20
|
if TYPE_CHECKING:
|
21
21
|
from collections.abc import Callable, Iterator
|
@@ -24,10 +24,12 @@ from itertools import chain
|
|
24
24
|
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
|
25
25
|
|
26
26
|
from mlflow.entities import RunStatus
|
27
|
+
from polars.dataframe import DataFrame
|
27
28
|
|
28
29
|
import hydraflow.param
|
29
|
-
from hydraflow.config import iter_params
|
30
|
-
from hydraflow.
|
30
|
+
from hydraflow.config import collect_params, iter_params
|
31
|
+
from hydraflow.run_data import RunCollectionData
|
32
|
+
from hydraflow.run_info import RunCollectionInfo
|
31
33
|
|
32
34
|
if TYPE_CHECKING:
|
33
35
|
from collections.abc import Callable, Iterator
|
@@ -61,8 +63,12 @@ class RunCollection:
|
|
61
63
|
_info: RunCollectionInfo = field(init=False)
|
62
64
|
"""An instance of `RunCollectionInfo`."""
|
63
65
|
|
66
|
+
_data: RunCollectionData = field(init=False)
|
67
|
+
"""An instance of `RunCollectionData`."""
|
68
|
+
|
64
69
|
def __post_init__(self) -> None:
|
65
70
|
self._info = RunCollectionInfo(self)
|
71
|
+
self._data = RunCollectionData(self)
|
66
72
|
|
67
73
|
def __repr__(self) -> str:
|
68
74
|
return f"{self.__class__.__name__}({len(self)})"
|
@@ -101,6 +107,11 @@ class RunCollection:
|
|
101
107
|
"""An instance of `RunCollectionInfo`."""
|
102
108
|
return self._info
|
103
109
|
|
110
|
+
@property
|
111
|
+
def data(self) -> RunCollectionData:
|
112
|
+
"""An instance of `RunCollectionData`."""
|
113
|
+
return self._data
|
114
|
+
|
104
115
|
def take(self, n: int) -> RunCollection:
|
105
116
|
"""Take the first n runs from the collection.
|
106
117
|
|
@@ -371,7 +382,7 @@ class RunCollection:
|
|
371
382
|
raise ValueError(msg)
|
372
383
|
|
373
384
|
def try_get(self, config: object | None = None, **kwargs) -> Run | None:
|
374
|
-
"""Try to
|
385
|
+
"""Try to get a specific `Run` instance based on the provided configuration.
|
375
386
|
|
376
387
|
This method filters the runs in the collection according to the
|
377
388
|
specified configuration object and returns the run that matches the
|
@@ -505,7 +516,7 @@ class RunCollection:
|
|
505
516
|
in the collection.
|
506
517
|
|
507
518
|
"""
|
508
|
-
return (func(config, *args, **kwargs) for config in self.
|
519
|
+
return (func(config, *args, **kwargs) for config in self.data.config)
|
509
520
|
|
510
521
|
def map_uri(
|
511
522
|
self,
|
@@ -584,6 +595,16 @@ class RunCollection:
|
|
584
595
|
|
585
596
|
return {key: RunCollection(runs) for key, runs in grouped_runs.items()}
|
586
597
|
|
598
|
+
@property
|
599
|
+
def config(self) -> DataFrame:
|
600
|
+
"""Get the runs' configurations as a polars DataFrame.
|
601
|
+
|
602
|
+
Returns:
|
603
|
+
A polars DataFrame containing the runs' configurations.
|
604
|
+
|
605
|
+
"""
|
606
|
+
return DataFrame(self.map_config(collect_params))
|
607
|
+
|
587
608
|
|
588
609
|
def _param_matches(run: Run, key: str, value: Any) -> bool:
|
589
610
|
params = run.data.params
|
@@ -634,8 +655,10 @@ def filter_runs(
|
|
634
655
|
"""
|
635
656
|
for key, value in chain(iter_params(config), kwargs.items()):
|
636
657
|
runs = [run for run in runs if _param_matches(run, key, value)]
|
658
|
+
if not runs:
|
659
|
+
return []
|
637
660
|
|
638
|
-
if
|
661
|
+
if status is None:
|
639
662
|
return runs
|
640
663
|
|
641
664
|
return filter_runs_by_status(runs, status)
|
@@ -0,0 +1,56 @@
|
|
1
|
+
"""Provide information about MLflow runs."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from typing import TYPE_CHECKING
|
6
|
+
|
7
|
+
from omegaconf import DictConfig, OmegaConf
|
8
|
+
|
9
|
+
from hydraflow.run_info import get_artifact_dir
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from mlflow.entities import Run
|
13
|
+
|
14
|
+
from hydraflow.run_collection import RunCollection
|
15
|
+
|
16
|
+
|
17
|
+
class RunCollectionData:
|
18
|
+
"""Provide information about MLflow runs."""
|
19
|
+
|
20
|
+
def __init__(self, runs: RunCollection) -> None:
|
21
|
+
self._runs = runs
|
22
|
+
|
23
|
+
@property
|
24
|
+
def params(self) -> list[dict[str, str]]:
|
25
|
+
"""Get the parameters for each run in the collection."""
|
26
|
+
return [run.data.params for run in self._runs]
|
27
|
+
|
28
|
+
@property
|
29
|
+
def metrics(self) -> list[dict[str, float]]:
|
30
|
+
"""Get the metrics for each run in the collection."""
|
31
|
+
return [run.data.metrics for run in self._runs]
|
32
|
+
|
33
|
+
@property
|
34
|
+
def config(self) -> list[DictConfig]:
|
35
|
+
"""Get the configuration for each run in the collection."""
|
36
|
+
return [load_config(run) for run in self._runs]
|
37
|
+
|
38
|
+
|
39
|
+
def load_config(run: Run) -> DictConfig:
|
40
|
+
"""Load the configuration for a given run.
|
41
|
+
|
42
|
+
This function loads the configuration for the provided Run instance
|
43
|
+
by downloading the configuration file from the MLflow artifacts and
|
44
|
+
loading it using OmegaConf. It returns an empty config if
|
45
|
+
`.hydra/config.yaml` is not found in the run's artifact directory.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
run (Run): The Run instance for which to load the configuration.
|
49
|
+
|
50
|
+
Returns:
|
51
|
+
The loaded configuration as a DictConfig object. Returns an empty
|
52
|
+
DictConfig if the configuration file is not found.
|
53
|
+
|
54
|
+
"""
|
55
|
+
path = get_artifact_dir(run) / ".hydra/config.yaml"
|
56
|
+
return OmegaConf.load(path) # type: ignore
|
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
|
|
8
8
|
import mlflow
|
9
9
|
from hydra.core.hydra_config import HydraConfig
|
10
10
|
from mlflow.tracking import artifact_utils
|
11
|
-
from omegaconf import
|
11
|
+
from omegaconf import OmegaConf
|
12
12
|
|
13
13
|
if TYPE_CHECKING:
|
14
14
|
from mlflow.entities import Run
|
@@ -27,16 +27,6 @@ class RunCollectionInfo:
|
|
27
27
|
"""Get the run ID for each run in the collection."""
|
28
28
|
return [run.info.run_id for run in self._runs]
|
29
29
|
|
30
|
-
@property
|
31
|
-
def params(self) -> list[dict[str, str]]:
|
32
|
-
"""Get the parameters for each run in the collection."""
|
33
|
-
return [run.data.params for run in self._runs]
|
34
|
-
|
35
|
-
@property
|
36
|
-
def metrics(self) -> list[dict[str, float]]:
|
37
|
-
"""Get the metrics for each run in the collection."""
|
38
|
-
return [run.data.metrics for run in self._runs]
|
39
|
-
|
40
30
|
@property
|
41
31
|
def artifact_uri(self) -> list[str | None]:
|
42
32
|
"""Get the artifact URI for each run in the collection."""
|
@@ -47,11 +37,6 @@ class RunCollectionInfo:
|
|
47
37
|
"""Get the artifact directory for each run in the collection."""
|
48
38
|
return [get_artifact_dir(run) for run in self._runs]
|
49
39
|
|
50
|
-
@property
|
51
|
-
def config(self) -> list[DictConfig]:
|
52
|
-
"""Get the configuration for each run in the collection."""
|
53
|
-
return [load_config(run) for run in self._runs]
|
54
|
-
|
55
40
|
|
56
41
|
def get_artifact_dir(run: Run | None = None) -> Path:
|
57
42
|
"""Retrieve the artifact directory for the given run.
|
@@ -104,23 +89,3 @@ def get_hydra_output_dir(run: Run | None = None) -> Path:
|
|
104
89
|
return Path(hc.hydra.runtime.output_dir)
|
105
90
|
|
106
91
|
raise FileNotFoundError
|
107
|
-
|
108
|
-
|
109
|
-
def load_config(run: Run) -> DictConfig:
|
110
|
-
"""Load the configuration for a given run.
|
111
|
-
|
112
|
-
This function loads the configuration for the provided Run instance
|
113
|
-
by downloading the configuration file from the MLflow artifacts and
|
114
|
-
loading it using OmegaConf. It returns an empty config if
|
115
|
-
`.hydra/config.yaml` is not found in the run's artifact directory.
|
116
|
-
|
117
|
-
Args:
|
118
|
-
run (Run): The Run instance for which to load the configuration.
|
119
|
-
|
120
|
-
Returns:
|
121
|
-
The loaded configuration as a DictConfig object. Returns an empty
|
122
|
-
DictConfig if the configuration file is not found.
|
123
|
-
|
124
|
-
"""
|
125
|
-
path = get_artifact_dir(run) / ".hydra/config.yaml"
|
126
|
-
return OmegaConf.load(path) # type: ignore
|
@@ -90,24 +90,24 @@ def test_app_info_run_id(rc: RunCollection):
|
|
90
90
|
assert len(rc.info.run_id) == 4
|
91
91
|
|
92
92
|
|
93
|
-
def
|
94
|
-
params = rc.
|
93
|
+
def test_app_data_params(rc: RunCollection):
|
94
|
+
params = rc.data.params
|
95
95
|
assert params[0] == {"port": "1", "host": "x", "values": "[1, 2, 3]"}
|
96
96
|
assert params[1] == {"port": "2", "host": "x", "values": "[1, 2, 3]"}
|
97
97
|
assert params[2] == {"port": "1", "host": "y", "values": "[1, 2, 3]"}
|
98
98
|
assert params[3] == {"port": "2", "host": "y", "values": "[1, 2, 3]"}
|
99
99
|
|
100
100
|
|
101
|
-
def
|
102
|
-
metrics = rc.
|
101
|
+
def test_app_data_metrics(rc: RunCollection):
|
102
|
+
metrics = rc.data.metrics
|
103
103
|
assert metrics[0] == {"m": 11, "watch": 3}
|
104
104
|
assert metrics[1] == {"m": 12, "watch": 3}
|
105
105
|
assert metrics[2] == {"m": 2, "watch": 3}
|
106
106
|
assert metrics[3] == {"m": 3, "watch": 3}
|
107
107
|
|
108
108
|
|
109
|
-
def
|
110
|
-
config = rc.
|
109
|
+
def test_app_data_config(rc: RunCollection):
|
110
|
+
config = rc.data.config
|
111
111
|
assert config[0].port == 1
|
112
112
|
assert config[1].port == 2
|
113
113
|
assert config[2].host == "y"
|
@@ -122,14 +122,14 @@ def test_app_info_artifact_uri(rc: RunCollection):
|
|
122
122
|
|
123
123
|
|
124
124
|
def test_app_info_artifact_dir(rc: RunCollection):
|
125
|
-
from hydraflow.
|
125
|
+
from hydraflow.run_info import get_artifact_dir
|
126
126
|
|
127
127
|
dirs = list(rc.map(get_artifact_dir))
|
128
128
|
assert rc.info.artifact_dir == dirs
|
129
129
|
|
130
130
|
|
131
131
|
def test_app_hydra_output_dir(rc: RunCollection):
|
132
|
-
from hydraflow.
|
132
|
+
from hydraflow.run_info import get_hydra_output_dir
|
133
133
|
|
134
134
|
dirs = list(rc.map(get_hydra_output_dir))
|
135
135
|
assert dirs[0].stem == "0"
|
@@ -154,13 +154,13 @@ def test_app_group_by(rc: RunCollection):
|
|
154
154
|
grouped = rc.group_by("host")
|
155
155
|
assert len(grouped) == 2
|
156
156
|
x = {"port": "1", "host": "x", "values": "[1, 2, 3]"}
|
157
|
-
assert grouped[("x",)].
|
157
|
+
assert grouped[("x",)].data.params[0] == x
|
158
158
|
x = {"port": "2", "host": "x", "values": "[1, 2, 3]"}
|
159
|
-
assert grouped[("x",)].
|
159
|
+
assert grouped[("x",)].data.params[1] == x
|
160
160
|
x = {"port": "1", "host": "y", "values": "[1, 2, 3]"}
|
161
|
-
assert grouped[("y",)].
|
161
|
+
assert grouped[("y",)].data.params[0] == x
|
162
162
|
x = {"port": "2", "host": "y", "values": "[1, 2, 3]"}
|
163
|
-
assert grouped[("y",)].
|
163
|
+
assert grouped[("y",)].data.params[1] == x
|
164
164
|
|
165
165
|
|
166
166
|
def test_app_filter_list(rc: RunCollection):
|
@@ -170,3 +170,11 @@ def test_app_filter_list(rc: RunCollection):
|
|
170
170
|
assert len(filtered) == 4
|
171
171
|
filtered = rc.filter(values=[1])
|
172
172
|
assert not filtered
|
173
|
+
|
174
|
+
|
175
|
+
def test_config(rc: RunCollection):
|
176
|
+
df = rc.config
|
177
|
+
assert df.columns == ["host", "port", "values"]
|
178
|
+
assert df.shape == (4, 3)
|
179
|
+
assert df.select("host").to_series().to_list() == ["x", "x", "y", "y"]
|
180
|
+
assert df.select("port").to_series().to_list() == [1, 2, 1, 2]
|
@@ -87,6 +87,14 @@ def test_iter_params():
|
|
87
87
|
assert next(it) == ("l.1.3", "c")
|
88
88
|
|
89
89
|
|
90
|
+
def test_collect_params():
|
91
|
+
from hydraflow.config import collect_params
|
92
|
+
|
93
|
+
conf = OmegaConf.create({"k": "v", "l": [1, {"a": "1", "b": "2", 3: "c"}]})
|
94
|
+
params = collect_params(conf)
|
95
|
+
assert params == {"k": "v", "l.0": 1, "l.1.a": "1", "l.1.b": "2", "l.1.3": "c"}
|
96
|
+
|
97
|
+
|
90
98
|
@dataclass
|
91
99
|
class Size:
|
92
100
|
x: int = 1
|
@@ -0,0 +1,43 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import mlflow
|
4
|
+
import pytest
|
5
|
+
|
6
|
+
from hydraflow.run_collection import RunCollection
|
7
|
+
|
8
|
+
|
9
|
+
@pytest.fixture
|
10
|
+
def runs(monkeypatch, tmp_path):
|
11
|
+
from hydraflow.mlflow import search_runs
|
12
|
+
|
13
|
+
monkeypatch.chdir(tmp_path)
|
14
|
+
|
15
|
+
mlflow.set_experiment("test_info")
|
16
|
+
|
17
|
+
for x in range(3):
|
18
|
+
with mlflow.start_run(run_name=f"{x}"):
|
19
|
+
mlflow.log_param("p", x)
|
20
|
+
mlflow.log_metric("metric1", x + 1)
|
21
|
+
mlflow.log_metric("metric2", x + 2)
|
22
|
+
|
23
|
+
x = search_runs()
|
24
|
+
assert isinstance(x, RunCollection)
|
25
|
+
return x
|
26
|
+
|
27
|
+
|
28
|
+
def test_data_params(runs: RunCollection):
|
29
|
+
assert runs.data.params == [{"p": "0"}, {"p": "1"}, {"p": "2"}]
|
30
|
+
|
31
|
+
|
32
|
+
def test_data_metrics(runs: RunCollection):
|
33
|
+
m = runs.data.metrics
|
34
|
+
assert m[0] == {"metric1": 1, "metric2": 2}
|
35
|
+
assert m[1] == {"metric1": 2, "metric2": 3}
|
36
|
+
assert m[2] == {"metric1": 3, "metric2": 4}
|
37
|
+
|
38
|
+
|
39
|
+
def test_data_empty_run_collection():
|
40
|
+
rc = RunCollection([])
|
41
|
+
assert rc.data.params == []
|
42
|
+
assert rc.data.metrics == []
|
43
|
+
assert rc.data.config == []
|
@@ -18,9 +18,7 @@ def runs(monkeypatch, tmp_path):
|
|
18
18
|
|
19
19
|
for x in range(3):
|
20
20
|
with mlflow.start_run(run_name=f"{x}"):
|
21
|
-
|
22
|
-
mlflow.log_metric("metric1", x + 1)
|
23
|
-
mlflow.log_metric("metric2", x + 2)
|
21
|
+
pass
|
24
22
|
|
25
23
|
x = search_runs()
|
26
24
|
assert isinstance(x, RunCollection)
|
@@ -31,17 +29,6 @@ def test_info_run_id(runs: RunCollection):
|
|
31
29
|
assert len(runs.info.run_id) == 3
|
32
30
|
|
33
31
|
|
34
|
-
def test_info_params(runs: RunCollection):
|
35
|
-
assert runs.info.params == [{"p": "0"}, {"p": "1"}, {"p": "2"}]
|
36
|
-
|
37
|
-
|
38
|
-
def test_info_metrics(runs: RunCollection):
|
39
|
-
m = runs.info.metrics
|
40
|
-
assert m[0] == {"metric1": 1, "metric2": 2}
|
41
|
-
assert m[1] == {"metric1": 2, "metric2": 3}
|
42
|
-
assert m[2] == {"metric1": 3, "metric2": 4}
|
43
|
-
|
44
|
-
|
45
32
|
def test_info_artifact_uri(runs: RunCollection):
|
46
33
|
uri = runs.info.artifact_uri
|
47
34
|
assert all(u.startswith("file://") for u in uri) # type: ignore
|
@@ -57,8 +44,5 @@ def test_info_artifact_dir(runs: RunCollection):
|
|
57
44
|
def test_info_empty_run_collection():
|
58
45
|
rc = RunCollection([])
|
59
46
|
assert rc.info.run_id == []
|
60
|
-
assert rc.info.params == []
|
61
|
-
assert rc.info.metrics == []
|
62
47
|
assert rc.info.artifact_uri == []
|
63
48
|
assert rc.info.artifact_dir == []
|
64
|
-
assert rc.info.config == []
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|