hydraflow 0.3.2__tar.gz → 0.3.4__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (42) hide show
  1. {hydraflow-0.3.2 → hydraflow-0.3.4}/PKG-INFO +1 -1
  2. {hydraflow-0.3.2 → hydraflow-0.3.4}/pyproject.toml +1 -1
  3. {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/__init__.py +11 -3
  4. {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/config.py +13 -0
  5. {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/context.py +1 -1
  6. {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/param.py +4 -1
  7. {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/run_collection.py +3 -13
  8. hydraflow-0.3.4/src/hydraflow/run_data.py +57 -0
  9. {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/utils.py +25 -0
  10. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/scripts/app.py +4 -2
  11. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_app.py +19 -39
  12. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_config.py +21 -0
  13. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_log_run.py +13 -1
  14. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_run_collection.py +17 -0
  15. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_run_data.py +6 -7
  16. hydraflow-0.3.2/src/hydraflow/run_data.py +0 -34
  17. {hydraflow-0.3.2 → hydraflow-0.3.4}/.devcontainer/devcontainer.json +0 -0
  18. {hydraflow-0.3.2 → hydraflow-0.3.4}/.devcontainer/postCreate.sh +0 -0
  19. {hydraflow-0.3.2 → hydraflow-0.3.4}/.devcontainer/starship.toml +0 -0
  20. {hydraflow-0.3.2 → hydraflow-0.3.4}/.gitattributes +0 -0
  21. {hydraflow-0.3.2 → hydraflow-0.3.4}/.gitignore +0 -0
  22. {hydraflow-0.3.2 → hydraflow-0.3.4}/LICENSE +0 -0
  23. {hydraflow-0.3.2 → hydraflow-0.3.4}/README.md +0 -0
  24. {hydraflow-0.3.2 → hydraflow-0.3.4}/apps/quickstart.py +0 -0
  25. {hydraflow-0.3.2 → hydraflow-0.3.4}/mkdocs.yml +0 -0
  26. {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/asyncio.py +0 -0
  27. {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/mlflow.py +0 -0
  28. {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/progress.py +0 -0
  29. {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/py.typed +0 -0
  30. {hydraflow-0.3.2 → hydraflow-0.3.4}/src/hydraflow/run_info.py +0 -0
  31. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/__init__.py +0 -0
  32. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/scripts/__init__.py +0 -0
  33. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/scripts/progress.py +0 -0
  34. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/scripts/watch.py +0 -0
  35. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_asyncio.py +0 -0
  36. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_context.py +0 -0
  37. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_mlflow.py +0 -0
  38. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_param.py +0 -0
  39. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_progress.py +0 -0
  40. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_run_info.py +0 -0
  41. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_version.py +0 -0
  42. {hydraflow-0.3.2 → hydraflow-0.3.4}/tests/test_watch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.3.2
3
+ Version: 0.3.4
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.3.2"
7
+ version = "0.3.4"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -1,19 +1,27 @@
1
1
  """Integrate Hydra and MLflow to manage and track machine learning experiments."""
2
2
 
3
- from .context import chdir_artifact, chdir_hydra, log_run, start_run, watch
3
+ from .context import chdir_artifact, chdir_hydra_output, log_run, start_run, watch
4
4
  from .mlflow import list_runs, search_runs, set_experiment
5
5
  from .progress import multi_tasks_progress, parallel_progress
6
6
  from .run_collection import RunCollection
7
- from .utils import get_artifact_dir, get_hydra_output_dir, load_config
7
+ from .utils import (
8
+ get_artifact_dir,
9
+ get_hydra_output_dir,
10
+ get_overrides,
11
+ load_config,
12
+ load_overrides,
13
+ )
8
14
 
9
15
  __all__ = [
10
16
  "RunCollection",
11
17
  "chdir_artifact",
12
- "chdir_hydra",
18
+ "chdir_hydra_output",
13
19
  "get_artifact_dir",
14
20
  "get_hydra_output_dir",
21
+ "get_overrides",
15
22
  "list_runs",
16
23
  "load_config",
24
+ "load_overrides",
17
25
  "log_run",
18
26
  "multi_tasks_progress",
19
27
  "parallel_progress",
@@ -44,12 +44,25 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
44
44
  if config is None:
45
45
  return
46
46
 
47
+ if isinstance(config, list) and all(isinstance(x, str) for x in config):
48
+ config = _from_dotlist(config)
49
+
47
50
  if not isinstance(config, DictConfig | ListConfig):
48
51
  config = OmegaConf.create(config) # type: ignore
49
52
 
50
53
  yield from _iter_params(config, prefix)
51
54
 
52
55
 
56
+ def _from_dotlist(config: list[str]) -> dict[str, str]:
57
+ result = {}
58
+ for item in config:
59
+ if "=" in item:
60
+ key, value = item.split("=", 1)
61
+ result[key.strip()] = value.strip()
62
+
63
+ return result
64
+
65
+
53
66
  def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
54
67
  if isinstance(config, DictConfig):
55
68
  for key, value in config.items():
@@ -239,7 +239,7 @@ class Handler(PatternMatchingEventHandler):
239
239
 
240
240
 
241
241
  @contextmanager
242
- def chdir_hydra() -> Iterator[Path]:
242
+ def chdir_hydra_output() -> Iterator[Path]:
243
243
  """Change the current working directory to the hydra output directory.
244
244
 
245
245
  This context manager changes the current working directory to the hydra output
@@ -34,7 +34,10 @@ def match(param: str, value: Any) -> bool:
34
34
  if isinstance(value, tuple) and (m := _match_tuple(param, value)) is not None:
35
35
  return m
36
36
 
37
- if isinstance(value, int | float | str):
37
+ if isinstance(value, str):
38
+ return param == value
39
+
40
+ if isinstance(value, int | float):
38
41
  return type(value)(param) == value
39
42
 
40
43
  return param == str(value)
@@ -24,12 +24,12 @@ from itertools import chain
24
24
  from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
25
25
 
26
26
  from mlflow.entities import RunStatus
27
- from polars.dataframe import DataFrame
28
27
 
29
28
  import hydraflow.param
30
- from hydraflow.config import collect_params, iter_params
29
+ from hydraflow.config import iter_params
31
30
  from hydraflow.run_data import RunCollectionData
32
31
  from hydraflow.run_info import RunCollectionInfo
32
+ from hydraflow.utils import load_config
33
33
 
34
34
  if TYPE_CHECKING:
35
35
  from collections.abc import Callable, Iterator
@@ -516,7 +516,7 @@ class RunCollection:
516
516
  in the collection.
517
517
 
518
518
  """
519
- return (func(config, *args, **kwargs) for config in self.data.config)
519
+ return (func(load_config(run), *args, **kwargs) for run in self)
520
520
 
521
521
  def map_uri(
522
522
  self,
@@ -599,16 +599,6 @@ class RunCollection:
599
599
 
600
600
  return {key: RunCollection(runs) for key, runs in grouped_runs.items()}
601
601
 
602
- @property
603
- def config(self) -> DataFrame:
604
- """Get the runs' configurations as a polars DataFrame.
605
-
606
- Returns:
607
- A polars DataFrame containing the runs' configurations.
608
-
609
- """
610
- return DataFrame(self.map_config(collect_params))
611
-
612
602
 
613
603
  def _param_matches(run: Run, key: str, value: Any) -> bool:
614
604
  params = run.data.params
@@ -0,0 +1,57 @@
1
+ """Provide data about `RunCollection` instances."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ from polars.dataframe import DataFrame
8
+
9
+ from hydraflow.config import collect_params
10
+
11
+ if TYPE_CHECKING:
12
+ from collections.abc import Iterable
13
+ from typing import Any
14
+
15
+ from hydraflow.run_collection import RunCollection
16
+
17
+
18
class RunCollectionData:
    """Expose the data of a `RunCollection` in column-oriented form."""

    def __init__(self, runs: RunCollection) -> None:
        # Keep the collection itself; every property re-reads it on access.
        self._runs = runs

    @property
    def params(self) -> dict[str, list[str]]:
        """Get the parameters for each run in the collection.

        Returns:
            A mapping from parameter name to that parameter's values, one
            entry per run in collection order.
        """
        per_run = (run.data.params for run in self._runs)
        return _to_dict(per_run)

    @property
    def metrics(self) -> dict[str, list[float]]:
        """Get the metrics for each run in the collection.

        Returns:
            A mapping from metric name to that metric's values, one entry
            per run in collection order.
        """
        per_run = (run.data.metrics for run in self._runs)
        return _to_dict(per_run)

    @property
    def config(self) -> DataFrame:
        """Get the runs' configurations as a polars DataFrame.

        Each run's configuration is flattened with ``collect_params`` and
        the resulting records are assembled into a single frame.

        Returns:
            A polars DataFrame containing the runs' configurations.
        """
        records = self._runs.map_config(collect_params)
        return DataFrame(records)
43
+
44
+
45
+ def _to_dict(it: Iterable[dict[str, Any]]) -> dict[str, list[Any]]:
46
+ """Convert an iterable of dictionaries to a dictionary of lists."""
47
+ data = list(it)
48
+ if not data:
49
+ return {}
50
+
51
+ keys = []
52
+ for d in data:
53
+ for key in d:
54
+ if key not in keys:
55
+ keys.append(key)
56
+
57
+ return {key: [x.get(key) for x in data] for key in keys}
@@ -86,3 +86,28 @@ def load_config(run: Run) -> DictConfig:
86
86
  """
87
87
  path = get_artifact_dir(run) / ".hydra/config.yaml"
88
88
  return OmegaConf.load(path) # type: ignore
89
+
90
+
91
def get_overrides() -> list[str]:
    """Retrieve the overrides for the current run.

    This reads the task overrides from ``HydraConfig``, so it is presumably
    only usable while a Hydra application is running (e.g. inside a
    ``@hydra.main`` function) — confirm against Hydra's ``HydraConfig`` docs.

    Returns:
        The task overrides (such as ``["host=x", "port=1"]``) as a plain
        list of strings, the same shape `load_overrides` returns.
    """
    # `HydraConfig.get().overrides.task` is an OmegaConf ListConfig at
    # runtime; convert it so callers get the `list[str]` the signature states.
    return [str(x) for x in HydraConfig.get().overrides.task]
94
+
95
+
96
def load_overrides(run: Run) -> list[str]:
    """Load the overrides for a given run.

    Reads ``.hydra/overrides.yaml`` from the run's artifact directory (as
    resolved by `get_artifact_dir`) with OmegaConf and returns its entries
    as strings.

    Args:
        run (Run): The Run instance for which to load the overrides.

    Returns:
        The loaded overrides as a list of strings. Returns an empty list
        if the overrides file is not found.

    """
    path = get_artifact_dir(run) / ".hydra/overrides.yaml"

    # Honor the documented contract: a run without a saved overrides file
    # has no overrides, instead of letting OmegaConf.load raise.
    if not path.exists():
        return []

    return [str(x) for x in OmegaConf.load(path)]
@@ -27,11 +27,11 @@ cs.store(name="config", node=MySQLConfig)
27
27
 
28
28
  @hydra.main(version_base=None, config_name="config")
29
29
  def app(cfg: MySQLConfig):
30
- with hydraflow.chdir_hydra() as path:
30
+ with hydraflow.chdir_hydra_output() as path:
31
31
  Path("chdir_hydra.txt").write_text(path.as_posix())
32
32
 
33
33
  hydraflow.set_experiment(prefix="_", suffix="_")
34
- with hydraflow.start_run(cfg):
34
+ with hydraflow.start_run(cfg) as run:
35
35
  log.info(f"START, {cfg.host}, {cfg.port} ")
36
36
 
37
37
  artifact_dir = hydraflow.get_artifact_dir()
@@ -50,6 +50,8 @@ def app(cfg: MySQLConfig):
50
50
  if cfg.host == "x":
51
51
  mlflow.log_metric("m", cfg.port + 10, 2)
52
52
 
53
+ assert hydraflow.get_overrides() == hydraflow.load_overrides(run)
54
+
53
55
  log.info("END")
54
56
 
55
57
 
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING
8
8
  import mlflow
9
9
  import pytest
10
10
  from mlflow.entities import RunStatus
11
- from omegaconf import ListConfig, OmegaConf
11
+ from omegaconf import OmegaConf
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  from omegaconf import DictConfig
@@ -92,33 +92,30 @@ def test_app_info_run_id(rc: RunCollection):
92
92
 
93
93
  def test_app_data_params(rc: RunCollection):
94
94
  params = rc.data.params
95
- assert params[0] == {"port": "1", "host": "x", "values": "[1, 2, 3]"}
96
- assert params[1] == {"port": "2", "host": "x", "values": "[1, 2, 3]"}
97
- assert params[2] == {"port": "1", "host": "y", "values": "[1, 2, 3]"}
98
- assert params[3] == {"port": "2", "host": "y", "values": "[1, 2, 3]"}
95
+ assert params["port"] == ["1", "2", "1", "2"]
96
+ assert params["host"] == ["x", "x", "y", "y"]
97
+ assert params["values"] == ["[1, 2, 3]", "[1, 2, 3]", "[1, 2, 3]", "[1, 2, 3]"]
99
98
 
100
99
 
101
100
  def test_app_data_metrics(rc: RunCollection):
102
101
  metrics = rc.data.metrics
103
- assert metrics[0] == {"m": 11, "watch": 3}
104
- assert metrics[1] == {"m": 12, "watch": 3}
105
- assert metrics[2] == {"m": 2, "watch": 3}
106
- assert metrics[3] == {"m": 3, "watch": 3}
102
+ assert metrics["m"] == [11, 12, 2, 3]
103
+ assert metrics["watch"] == [3, 3, 3, 3]
107
104
 
108
105
 
109
106
  def test_app_data_config(rc: RunCollection):
110
107
  config = rc.data.config
111
- assert config[0].port == 1
112
- assert config[1].port == 2
113
- assert config[2].host == "y"
114
- assert config[3].host == "y"
108
+ assert config["port"].to_list() == [1, 2, 1, 2]
109
+ assert config["host"].to_list() == ["x", "x", "y", "y"]
115
110
 
116
111
 
117
112
  def test_app_data_config_list(rc: RunCollection):
118
113
  config = rc.data.config
119
- assert isinstance(config[0]["values"], ListConfig)
120
- assert not isinstance(config[0]["values"], list)
121
- assert config[0]["values"] == [1, 2, 3]
114
+ values = config["values"].to_list()
115
+ assert str(config.select("values").dtypes) == "[List(Int64)]"
116
+ for x in values:
117
+ assert isinstance(x, list)
118
+ assert x == [1, 2, 3]
122
119
 
123
120
 
124
121
  def test_app_info_artifact_uri(rc: RunCollection):
@@ -160,14 +157,12 @@ def test_app_map_config(rc: RunCollection):
160
157
  def test_app_group_by(rc: RunCollection):
161
158
  grouped = rc.group_by("host")
162
159
  assert len(grouped) == 2
163
- x = {"port": "1", "host": "x", "values": "[1, 2, 3]"}
164
- assert grouped["x"].data.params[0] == x
165
- x = {"port": "2", "host": "x", "values": "[1, 2, 3]"}
166
- assert grouped["x"].data.params[1] == x
167
- x = {"port": "1", "host": "y", "values": "[1, 2, 3]"}
168
- assert grouped["y"].data.params[0] == x
169
- x = {"port": "2", "host": "y", "values": "[1, 2, 3]"}
170
- assert grouped["y"].data.params[1] == x
160
+ assert grouped["x"].data.params["port"] == ["1", "2"]
161
+ assert grouped["x"].data.params["host"] == ["x", "x"]
162
+ assert grouped["x"].data.params["values"] == ["[1, 2, 3]", "[1, 2, 3]"]
163
+ assert grouped["y"].data.params["port"] == ["1", "2"]
164
+ assert grouped["y"].data.params["host"] == ["y", "y"]
165
+ assert grouped["y"].data.params["values"] == ["[1, 2, 3]", "[1, 2, 3]"]
171
166
 
172
167
 
173
168
  def test_app_group_by_list(rc: RunCollection):
@@ -184,18 +179,3 @@ def test_app_filter_list(rc: RunCollection):
184
179
  assert len(filtered) == 4
185
180
  filtered = rc.filter(values=[1])
186
181
  assert not filtered
187
-
188
-
189
- def test_config(rc: RunCollection):
190
- df = rc.config
191
- assert df.columns == ["host", "port", "values"]
192
- assert df.shape == (4, 3)
193
- assert df.select("host").to_series().to_list() == ["x", "x", "y", "y"]
194
- assert df.select("port").to_series().to_list() == [1, 2, 1, 2]
195
- assert str(df.select("values").dtypes) == "[List(Int64)]"
196
- assert df.select("values").to_series().to_list() == [
197
- [1, 2, 3],
198
- [1, 2, 3],
199
- [1, 2, 3],
200
- [1, 2, 3],
201
- ]
@@ -205,3 +205,24 @@ def test_list_config_str(s):
205
205
  assert isinstance(b, ListConfig)
206
206
  t = OmegaConf.create(json.loads(s))
207
207
  assert b == t
208
+
209
+
210
@pytest.mark.parametrize("x", [{"a": 1}, {"a": [1, 2, 3]}])
def test_collect_params_dict(x):
    """A flat dict, including list values, is returned unchanged."""
    from hydraflow.config import collect_params as collect

    result = collect(x)
    assert result == x
215
+
216
+
217
def test_collect_params_dict_dot():
    """Nested keys flatten to dotted names; already-dotted keys are kept."""
    from hydraflow.config import collect_params as collect

    nested = {"a": {"b": 1}}
    assert collect(nested) == {"a.b": 1}

    already_dotted = {"a.b": 1}
    assert collect(already_dotted) == {"a.b": 1}
222
+
223
+
224
def test_collect_params_list_dot():
    """Dotlist strings parse as ``key=value``; items without ``=`` are dropped."""
    from hydraflow.config import collect_params as collect

    assert collect(["a=1"]) == {"a": "1"}

    dotlist = ["a.b=2", "c"]
    assert collect(dotlist) == {"a.b": "2"}
@@ -50,7 +50,7 @@ def read_log(run_id: str, path: str) -> str:
50
50
 
51
51
 
52
52
  def test_load_config(run: Run):
53
- from hydraflow.run_data import load_config
53
+ from hydraflow.utils import load_config
54
54
 
55
55
  log = read_log(run.info.run_id, "log_run.log")
56
56
  assert "START" in log
@@ -63,6 +63,18 @@ def test_load_config(run: Run):
63
63
  assert cfg.port == int(port)
64
64
 
65
65
 
66
def test_load_overrides(run: Run):
    """The saved overrides match the host and port logged on the START line."""
    from hydraflow.utils import load_overrides

    log = read_log(run.info.run_id, "log_run.log")
    assert "START" in log
    assert "END" in log

    first_line = log.splitlines()[0]
    host, port = first_line.split("START,")[-1].split(",")

    expected = [f"host={host.strip()}", f"port={port.strip()}"]
    assert load_overrides(run) == expected
76
+
77
+
66
78
  def test_info(run: Run):
67
79
  log = read_log(run.info.run_id, "artifact_dir.txt")
68
80
  a, b = log.split(" ")
@@ -67,6 +67,8 @@ def test_filter_one(run_list: list[Run]):
67
67
  assert len(x) == 1
68
68
  x = filter_runs(run_list, p=1)
69
69
  assert len(x) == 1
70
+ x = filter_runs(run_list, ["p=1"])
71
+ assert len(x) == 1
70
72
 
71
73
 
72
74
  def test_filter_all(run_list: list[Run]):
@@ -77,6 +79,8 @@ def test_filter_all(run_list: list[Run]):
77
79
  assert len(x) == 5
78
80
  x = filter_runs(run_list, q=0)
79
81
  assert len(x) == 5
82
+ x = filter_runs(run_list, ["q=0"])
83
+ assert len(x) == 5
80
84
 
81
85
 
82
86
  def test_filter_list(run_list: list[Run]):
@@ -98,6 +102,8 @@ def test_filter_invalid_param(run_list: list[Run]):
98
102
 
99
103
  x = filter_runs(run_list, {"invalid": 0})
100
104
  assert len(x) == 6
105
+ x = filter_runs(run_list, ["invalid=0"])
106
+ assert len(x) == 6
101
107
 
102
108
 
103
109
  def test_filter_status(run_list: list[Run]):
@@ -181,15 +187,20 @@ def test_filter(rc: RunCollection):
181
187
  assert len(rc.filter()) == 6
182
188
  assert len(rc.filter({})) == 6
183
189
  assert len(rc.filter({"p": 1})) == 1
190
+ assert len(rc.filter(["p=1"])) == 1
184
191
  assert len(rc.filter({"q": 0})) == 5
192
+ assert len(rc.filter(["q=0"])) == 5
185
193
  assert len(rc.filter({"q": -1})) == 0
194
+ assert len(rc.filter(["q=-1"])) == 0
186
195
  assert not rc.filter({"q": -1})
187
196
  assert len(rc.filter(p=5)) == 1
188
197
  assert len(rc.filter(q=0)) == 5
189
198
  assert len(rc.filter(q=-1)) == 0
190
199
  assert not rc.filter(q=-1)
191
200
  assert len(rc.filter({"r": 2})) == 2
201
+ assert len(rc.filter(["r=2"])) == 2
192
202
  assert len(rc.filter(r=0)) == 2
203
+ assert len(rc.filter(["r=0"])) == 2
193
204
 
194
205
 
195
206
  def test_get(rc: RunCollection):
@@ -197,15 +208,21 @@ def test_get(rc: RunCollection):
197
208
  assert isinstance(run, Run)
198
209
  run = rc.get(p=2)
199
210
  assert isinstance(run, Run)
211
+ run = rc.get(["p=3"])
212
+ assert isinstance(run, Run)
200
213
 
201
214
 
202
215
  def test_try_get(rc: RunCollection):
203
216
  run = rc.try_get({"p": 5})
204
217
  assert isinstance(run, Run)
218
+ run = rc.try_get(["p=2"])
219
+ assert isinstance(run, Run)
205
220
  run = rc.try_get(p=1)
206
221
  assert isinstance(run, Run)
207
222
  run = rc.try_get(p=-1)
208
223
  assert run is None
224
+ run = rc.try_get(["p=-2"])
225
+ assert run is None
209
226
 
210
227
 
211
228
  def test_get_param_names(rc: RunCollection):
@@ -26,18 +26,17 @@ def runs(monkeypatch, tmp_path):
26
26
 
27
27
 
28
28
  def test_data_params(runs: RunCollection):
29
- assert runs.data.params == [{"p": "0"}, {"p": "1"}, {"p": "2"}]
29
+ assert runs.data.params["p"] == ["0", "1", "2"]
30
30
 
31
31
 
32
32
  def test_data_metrics(runs: RunCollection):
33
33
  m = runs.data.metrics
34
- assert m[0] == {"metric1": 1, "metric2": 2}
35
- assert m[1] == {"metric1": 2, "metric2": 3}
36
- assert m[2] == {"metric1": 3, "metric2": 4}
34
+ assert m["metric1"] == [1, 2, 3]
35
+ assert m["metric2"] == [2, 3, 4]
37
36
 
38
37
 
39
38
  def test_data_empty_run_collection():
40
39
  rc = RunCollection([])
41
- assert rc.data.params == []
42
- assert rc.data.metrics == []
43
- assert rc.data.config == []
40
+ assert rc.data.params == {}
41
+ assert rc.data.metrics == {}
42
+ assert len(rc.data.config) == 0
@@ -1,34 +0,0 @@
1
- """Provide data about `RunCollection` instances."""
2
-
3
- from __future__ import annotations
4
-
5
- from typing import TYPE_CHECKING
6
-
7
- from hydraflow.utils import load_config
8
-
9
- if TYPE_CHECKING:
10
- from omegaconf import DictConfig
11
-
12
- from hydraflow.run_collection import RunCollection
13
-
14
-
15
- class RunCollectionData:
16
- """Provide data about a `RunCollection` instance."""
17
-
18
- def __init__(self, runs: RunCollection) -> None:
19
- self._runs = runs
20
-
21
- @property
22
- def params(self) -> list[dict[str, str]]:
23
- """Get the parameters for each run in the collection."""
24
- return [run.data.params for run in self._runs]
25
-
26
- @property
27
- def metrics(self) -> list[dict[str, float]]:
28
- """Get the metrics for each run in the collection."""
29
- return [run.data.metrics for run in self._runs]
30
-
31
- @property
32
- def config(self) -> list[DictConfig]:
33
- """Get the configuration for each run in the collection."""
34
- return [load_config(run) for run in self._runs]
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes