hydraflow 0.2.16__tar.gz → 0.2.17__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (37) hide show
  1. {hydraflow-0.2.16 → hydraflow-0.2.17}/PKG-INFO +1 -1
  2. {hydraflow-0.2.16 → hydraflow-0.2.17}/pyproject.toml +2 -1
  3. hydraflow-0.2.17/src/hydraflow/param.py +64 -0
  4. {hydraflow-0.2.16 → hydraflow-0.2.17}/src/hydraflow/run_collection.py +10 -28
  5. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/scripts/app.py +2 -1
  6. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/test_app.py +22 -8
  7. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/test_config.py +22 -1
  8. hydraflow-0.2.17/tests/test_param.py +78 -0
  9. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/test_run_collection.py +130 -117
  10. {hydraflow-0.2.16 → hydraflow-0.2.17}/.devcontainer/devcontainer.json +0 -0
  11. {hydraflow-0.2.16 → hydraflow-0.2.17}/.devcontainer/postCreate.sh +0 -0
  12. {hydraflow-0.2.16 → hydraflow-0.2.17}/.devcontainer/starship.toml +0 -0
  13. {hydraflow-0.2.16 → hydraflow-0.2.17}/.gitattributes +0 -0
  14. {hydraflow-0.2.16 → hydraflow-0.2.17}/.gitignore +0 -0
  15. {hydraflow-0.2.16 → hydraflow-0.2.17}/LICENSE +0 -0
  16. {hydraflow-0.2.16 → hydraflow-0.2.17}/README.md +0 -0
  17. {hydraflow-0.2.16 → hydraflow-0.2.17}/mkdocs.yml +0 -0
  18. {hydraflow-0.2.16 → hydraflow-0.2.17}/src/hydraflow/__init__.py +0 -0
  19. {hydraflow-0.2.16 → hydraflow-0.2.17}/src/hydraflow/asyncio.py +0 -0
  20. {hydraflow-0.2.16 → hydraflow-0.2.17}/src/hydraflow/config.py +0 -0
  21. {hydraflow-0.2.16 → hydraflow-0.2.17}/src/hydraflow/context.py +0 -0
  22. {hydraflow-0.2.16 → hydraflow-0.2.17}/src/hydraflow/info.py +0 -0
  23. {hydraflow-0.2.16 → hydraflow-0.2.17}/src/hydraflow/mlflow.py +0 -0
  24. {hydraflow-0.2.16 → hydraflow-0.2.17}/src/hydraflow/progress.py +0 -0
  25. {hydraflow-0.2.16 → hydraflow-0.2.17}/src/hydraflow/py.typed +0 -0
  26. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/__init__.py +0 -0
  27. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/scripts/__init__.py +0 -0
  28. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/scripts/progress.py +0 -0
  29. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/scripts/watch.py +0 -0
  30. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/test_asyncio.py +0 -0
  31. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/test_context.py +0 -0
  32. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/test_info.py +0 -0
  33. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/test_log_run.py +0 -0
  34. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/test_mlflow.py +0 -0
  35. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/test_progress.py +0 -0
  36. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/test_version.py +0 -0
  37. {hydraflow-0.2.16 → hydraflow-0.2.17}/tests/test_watch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.2.16
3
+ Version: 0.2.17
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.2.16"
7
+ version = "0.2.17"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -66,6 +66,7 @@ target-version = "py310"
66
66
 
67
67
  [tool.ruff.lint]
68
68
  select = ["ALL"]
69
+ unfixable = ["F401"]
69
70
  ignore = [
70
71
  "ANN003",
71
72
  "ANN401",
@@ -0,0 +1,64 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+
6
+ def match(param: str, value: Any) -> bool:
7
+ """Check if the string matches the specified value.
8
+
9
+ Args:
10
+ param (str): The parameter to check.
11
+ value (Any): The value to check.
12
+
13
+ Returns:
14
+ True if the parameter matches the specified value,
15
+ False otherwise.
16
+ """
17
+ if value in [None, True, False]:
18
+ return param == str(value)
19
+
20
+ if isinstance(value, list) and (m := _match_list(param, value)) is not None:
21
+ return m
22
+
23
+ if isinstance(value, tuple) and (m := _match_tuple(param, value)) is not None:
24
+ return m
25
+
26
+ if isinstance(value, int | float | str):
27
+ return type(value)(param) == value
28
+
29
+ return param == str(value)
30
+
31
+
32
+ def _match_list(param: str, value: list) -> bool | None:
33
+ if not value:
34
+ return None
35
+
36
+ if any(param.startswith(x) for x in ["[", "(", "{"]):
37
+ return None
38
+
39
+ if isinstance(value[0], bool):
40
+ return None
41
+
42
+ if not isinstance(value[0], int | float | str):
43
+ return None
44
+
45
+ return type(value[0])(param) in value
46
+
47
+
48
+ def _match_tuple(param: str, value: tuple) -> bool | None:
49
+ if len(value) != 2: # noqa: PLR2004
50
+ return None
51
+
52
+ if any(param.startswith(x) for x in ["[", "(", "{"]):
53
+ return None
54
+
55
+ if isinstance(value[0], bool):
56
+ return None
57
+
58
+ if not isinstance(value[0], int | float | str):
59
+ return None
60
+
61
+ if type(value[0]) is not type(value[1]):
62
+ return None
63
+
64
+ return value[0] <= type(value[0])(param) < value[1] # type: ignore
@@ -23,6 +23,7 @@ from dataclasses import dataclass, field
23
23
  from itertools import chain
24
24
  from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
25
25
 
26
+ import hydraflow.param
26
27
  from hydraflow.config import iter_params
27
28
  from hydraflow.info import RunCollectionInfo
28
29
 
@@ -86,6 +87,9 @@ class RunCollection:
86
87
  def __contains__(self, run: Run) -> bool:
87
88
  return run in self._runs
88
89
 
90
+ def __bool__(self) -> bool:
91
+ return bool(self._runs)
92
+
89
93
  @classmethod
90
94
  def from_list(cls, runs: list[Run]) -> RunCollection:
91
95
  """Create a `RunCollection` instance from a list of MLflow `Run` instances."""
@@ -569,37 +573,15 @@ class RunCollection:
569
573
 
570
574
 
571
575
  def _param_matches(run: Run, key: str, value: Any) -> bool:
572
- """
573
- Check if the run's parameter matches the specified key-value pair.
574
-
575
- Check if the run's parameters contain the specified
576
- key-value pair. It handles different types of values, including lists
577
- and tuples.
578
-
579
- Args:
580
- run (Run): The run object to check.
581
- key (str): The parameter key to check.
582
- value (Any): The parameter value to check.
583
-
584
- Returns:
585
- True if the run's parameter matches the specified key-value pair,
586
- False otherwise.
587
- """
588
- param = run.data.params.get(key, value)
589
-
590
- if param is None:
591
- return False
576
+ params = run.data.params
577
+ if key not in params:
578
+ return True
592
579
 
580
+ param = params[key]
593
581
  if param == "None":
594
- return value is None
595
-
596
- if isinstance(value, list) and value:
597
- return type(value[0])(param) in value
598
-
599
- if isinstance(value, tuple) and len(value) == 2: # noqa: PLR2004
600
- return value[0] <= type(value[0])(param) < value[1]
582
+ return value is None or value == "None"
601
583
 
602
- return type(value)(param) == value
584
+ return hydraflow.param.match(param, value)
603
585
 
604
586
 
605
587
  def filter_runs(
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  import time
5
- from dataclasses import dataclass
5
+ from dataclasses import dataclass, field
6
6
  from pathlib import Path
7
7
 
8
8
  import hydra
@@ -18,6 +18,7 @@ log = logging.getLogger(__name__)
18
18
  class MySQLConfig:
19
19
  host: str = "localhost"
20
20
  port: int = 3306
21
+ values: list[int] = field(default_factory=lambda: [1, 2, 3])
21
22
 
22
23
 
23
24
  cs = ConfigStore.instance()
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING
7
7
 
8
8
  import mlflow
9
9
  import pytest
10
+ from omegaconf import OmegaConf
10
11
 
11
12
  if TYPE_CHECKING:
12
13
  from omegaconf import DictConfig
@@ -77,10 +78,10 @@ def test_app_info_run_id(rc: RunCollection):
77
78
 
78
79
  def test_app_info_params(rc: RunCollection):
79
80
  params = rc.info.params
80
- assert params[0] == {"port": "1", "host": "x"}
81
- assert params[1] == {"port": "2", "host": "x"}
82
- assert params[2] == {"port": "1", "host": "y"}
83
- assert params[3] == {"port": "2", "host": "y"}
81
+ assert params[0] == {"port": "1", "host": "x", "values": "[1, 2, 3]"}
82
+ assert params[1] == {"port": "2", "host": "x", "values": "[1, 2, 3]"}
83
+ assert params[2] == {"port": "1", "host": "y", "values": "[1, 2, 3]"}
84
+ assert params[3] == {"port": "2", "host": "y", "values": "[1, 2, 3]"}
84
85
 
85
86
 
86
87
  def test_app_info_metrics(rc: RunCollection):
@@ -138,7 +139,20 @@ def test_app_map_config(rc: RunCollection):
138
139
  def test_app_group_by(rc: RunCollection):
139
140
  grouped = rc.group_by("host")
140
141
  assert len(grouped) == 2
141
- assert grouped[("x",)].info.params[0] == {"port": "1", "host": "x"}
142
- assert grouped[("x",)].info.params[1] == {"port": "2", "host": "x"}
143
- assert grouped[("y",)].info.params[0] == {"port": "1", "host": "y"}
144
- assert grouped[("y",)].info.params[1] == {"port": "2", "host": "y"}
142
+ x = {"port": "1", "host": "x", "values": "[1, 2, 3]"}
143
+ assert grouped[("x",)].info.params[0] == x
144
+ x = {"port": "2", "host": "x", "values": "[1, 2, 3]"}
145
+ assert grouped[("x",)].info.params[1] == x
146
+ x = {"port": "1", "host": "y", "values": "[1, 2, 3]"}
147
+ assert grouped[("y",)].info.params[0] == x
148
+ x = {"port": "2", "host": "y", "values": "[1, 2, 3]"}
149
+ assert grouped[("y",)].info.params[1] == x
150
+
151
+
152
+ def test_app_filter_list(rc: RunCollection):
153
+ filtered = rc.filter(values=[1, 2, 3])
154
+ assert len(filtered) == 4
155
+ filtered = rc.filter(values=OmegaConf.create([1, 2, 3]))
156
+ assert len(filtered) == 4
157
+ filtered = rc.filter(values=[1])
158
+ assert not filtered
@@ -1,7 +1,8 @@
1
+ import json
1
2
  from dataclasses import dataclass, field
2
3
 
3
4
  import pytest
4
- from omegaconf import OmegaConf
5
+ from omegaconf import ListConfig, OmegaConf
5
6
 
6
7
 
7
8
  def test_is_param_with_simple_values():
@@ -176,3 +177,23 @@ def test_iter_params_with_mixed_types_in_list():
176
177
  assert next(it) == ("items.0", "a")
177
178
  assert next(it) == ("items.1", 1)
178
179
  assert next(it) == ("items.2.key", "value")
180
+
181
+
182
+ @pytest.mark.parametrize("type_", [int, float])
183
+ @pytest.mark.parametrize("s", ["[1, 2, 3]", "[1.0, 2.0, 3.0]"])
184
+ def test_list_config(type_, s):
185
+ a = [type_(x) for x in [1, 2, 3]]
186
+ b = OmegaConf.create(a)
187
+ assert isinstance(b, ListConfig)
188
+ t = OmegaConf.create(json.loads(s))
189
+ assert b == t
190
+ assert a == t
191
+
192
+
193
+ @pytest.mark.parametrize("s", ['["a", "b", "c"]'])
194
+ def test_list_config_str(s):
195
+ a = ["a", "b", "c"]
196
+ b = OmegaConf.create(a)
197
+ assert isinstance(b, ListConfig)
198
+ t = OmegaConf.create(json.loads(s))
199
+ assert b == t
@@ -0,0 +1,78 @@
1
+ import mlflow
2
+ import pytest
3
+
4
+
5
+ @pytest.fixture
6
+ def param(monkeypatch, tmp_path):
7
+ def param(value):
8
+ monkeypatch.chdir(tmp_path)
9
+ mlflow.set_experiment("test_param")
10
+
11
+ with mlflow.start_run():
12
+ mlflow.log_param("p", value, synchronous=True)
13
+
14
+ runs = mlflow.search_runs(output_format="list")
15
+ p = runs[0].data.params["p"]
16
+ assert isinstance(p, str)
17
+ return p
18
+
19
+ return param
20
+
21
+
22
+ @pytest.mark.parametrize(
23
+ ("x", "y"),
24
+ [
25
+ (1, "1"),
26
+ (1.0, "1.0"),
27
+ ("1", "1"),
28
+ ("a", "a"),
29
+ ("'a'", "'a'"),
30
+ ('"a"', '"a"'),
31
+ (True, "True"),
32
+ (False, "False"),
33
+ (None, "None"),
34
+ ([], "[]"),
35
+ ((), "()"),
36
+ ({}, "{}"),
37
+ ([1, 2, 3], "[1, 2, 3]"),
38
+ (["1", "2", "3"], "['1', '2', '3']"),
39
+ (("1", "2", "3"), "('1', '2', '3')"),
40
+ ({"a": 1, "b": "c"}, "{'a': 1, 'b': 'c'}"),
41
+ ],
42
+ )
43
+ def test_param(param, x, y):
44
+ from hydraflow.param import match
45
+
46
+ p = param(x)
47
+ assert p == y
48
+ assert str(x) == y
49
+ assert match(p, x)
50
+
51
+
52
+ def test_match_list():
53
+ from hydraflow.param import _match_list
54
+
55
+ assert _match_list("1", [1, 2, 3]) is True
56
+ assert _match_list("[1]", [1, 2, 3]) is None
57
+ assert _match_list("(1,)", [1, 2, 3]) is None
58
+ assert _match_list("{1: 3}", [1, 2, 3]) is None
59
+ assert _match_list("2", [1, 2, 3]) is True
60
+ assert _match_list("4", [1, 2, 3]) is False
61
+ assert _match_list("4", [True]) is None
62
+ assert _match_list("4", [None]) is None
63
+ assert _match_list("4", ["4"]) is True
64
+ assert _match_list("4", ["a"]) is False
65
+
66
+
67
+ def test_match_tuple():
68
+ from hydraflow.param import _match_tuple
69
+
70
+ assert _match_tuple("1", (1, 3)) is True
71
+ assert _match_tuple("2", (1, 3)) is True
72
+ assert _match_tuple("4", (1, 3)) is False
73
+ assert _match_tuple("[1]", (1, 3)) is None
74
+ assert _match_tuple("(1,)", (1, 3)) is None
75
+ assert _match_tuple("{1: 3}", (1, 3)) is None
76
+ assert _match_tuple("1", (True, False)) is None
77
+ assert _match_tuple("1", (None, None)) is None
78
+ assert _match_tuple("1", (1, 3.2)) is None
@@ -10,7 +10,7 @@ from hydraflow.run_collection import RunCollection
10
10
 
11
11
 
12
12
  @pytest.fixture
13
- def runs(monkeypatch, tmp_path):
13
+ def rc(monkeypatch, tmp_path):
14
14
  from hydraflow.mlflow import search_runs
15
15
 
16
16
  monkeypatch.chdir(tmp_path)
@@ -28,9 +28,19 @@ def runs(monkeypatch, tmp_path):
28
28
  return x
29
29
 
30
30
 
31
+ def test_run_collection_bool_false():
32
+ assert not RunCollection([])
33
+ assert bool(RunCollection.from_list([])) is False
34
+
35
+
36
+ def test_run_collection_bool_true(rc: RunCollection):
37
+ assert rc
38
+ assert bool(rc) is True
39
+
40
+
31
41
  @pytest.fixture
32
- def run_list(runs: RunCollection):
33
- return runs._runs
42
+ def run_list(rc: RunCollection):
43
+ return rc._runs
34
44
 
35
45
 
36
46
  def test_from_list(run_list: list[Run]):
@@ -120,123 +130,125 @@ def test_chdir_artifact_list(i: int, run_list: list[Run]):
120
130
  assert not Path("abc.txt").exists()
121
131
 
122
132
 
123
- def test_runs_repr(runs: RunCollection):
124
- assert repr(runs) == "RunCollection(6)"
133
+ def test_runs_repr(rc: RunCollection):
134
+ assert repr(rc) == "RunCollection(6)"
125
135
 
126
136
 
127
- def test_runs_first(runs: RunCollection):
128
- run = runs.first()
137
+ def test_runs_first(rc: RunCollection):
138
+ run = rc.first()
129
139
  assert isinstance(run, Run)
130
140
  assert run.data.params["p"] == "0"
131
141
 
132
142
 
133
- def test_runs_first_empty(runs: RunCollection):
134
- runs._runs = []
143
+ def test_runs_first_empty(rc: RunCollection):
144
+ rc._runs = []
135
145
  with pytest.raises(ValueError):
136
- runs.first()
146
+ rc.first()
137
147
 
138
148
 
139
- def test_runs_try_first_none(runs: RunCollection):
140
- runs._runs = []
141
- assert runs.try_first() is None
149
+ def test_runs_try_first_none(rc: RunCollection):
150
+ rc._runs = []
151
+ assert rc.try_first() is None
142
152
 
143
153
 
144
- def test_runs_last(runs: RunCollection):
145
- run = runs.last()
154
+ def test_runs_last(rc: RunCollection):
155
+ run = rc.last()
146
156
  assert isinstance(run, Run)
147
157
  assert run.data.params["p"] == "5"
148
158
 
149
159
 
150
- def test_runs_last_empty(runs: RunCollection):
151
- runs._runs = []
160
+ def test_runs_last_empty(rc: RunCollection):
161
+ rc._runs = []
152
162
  with pytest.raises(ValueError):
153
- runs.last()
163
+ rc.last()
154
164
 
155
165
 
156
- def test_runs_try_last_none(runs: RunCollection):
157
- runs._runs = []
158
- assert runs.try_last() is None
166
+ def test_runs_try_last_none(rc: RunCollection):
167
+ rc._runs = []
168
+ assert rc.try_last() is None
159
169
 
160
170
 
161
- def test_runs_filter(runs: RunCollection):
162
- assert len(runs.filter()) == 6
163
- assert len(runs.filter({})) == 6
164
- assert len(runs.filter({"p": 1})) == 1
165
- assert len(runs.filter({"q": 0})) == 5
166
- assert len(runs.filter({"q": -1})) == 0
167
- assert len(runs.filter(p=5)) == 1
168
- assert len(runs.filter(q=0)) == 5
169
- assert len(runs.filter(q=-1)) == 0
170
- assert len(runs.filter({"r": 2})) == 2
171
- assert len(runs.filter(r=0)) == 2
171
+ def test_runs_filter(rc: RunCollection):
172
+ assert len(rc.filter()) == 6
173
+ assert len(rc.filter({})) == 6
174
+ assert len(rc.filter({"p": 1})) == 1
175
+ assert len(rc.filter({"q": 0})) == 5
176
+ assert len(rc.filter({"q": -1})) == 0
177
+ assert not rc.filter({"q": -1})
178
+ assert len(rc.filter(p=5)) == 1
179
+ assert len(rc.filter(q=0)) == 5
180
+ assert len(rc.filter(q=-1)) == 0
181
+ assert not rc.filter(q=-1)
182
+ assert len(rc.filter({"r": 2})) == 2
183
+ assert len(rc.filter(r=0)) == 2
172
184
 
173
185
 
174
- def test_runs_get(runs: RunCollection):
175
- run = runs.get({"p": 4})
186
+ def test_runs_get(rc: RunCollection):
187
+ run = rc.get({"p": 4})
176
188
  assert isinstance(run, Run)
177
- run = runs.get(p=2)
189
+ run = rc.get(p=2)
178
190
  assert isinstance(run, Run)
179
191
 
180
192
 
181
- def test_runs_try_get(runs: RunCollection):
182
- run = runs.try_get({"p": 5})
193
+ def test_runs_try_get(rc: RunCollection):
194
+ run = rc.try_get({"p": 5})
183
195
  assert isinstance(run, Run)
184
- run = runs.try_get(p=1)
196
+ run = rc.try_get(p=1)
185
197
  assert isinstance(run, Run)
186
- run = runs.try_get(p=-1)
198
+ run = rc.try_get(p=-1)
187
199
  assert run is None
188
200
 
189
201
 
190
- def test_runs_get_params_names(runs: RunCollection):
191
- names = runs.get_param_names()
202
+ def test_runs_get_params_names(rc: RunCollection):
203
+ names = rc.get_param_names()
192
204
  assert len(names) == 3
193
205
  assert "p" in names
194
206
  assert "q" in names
195
207
  assert "r" in names
196
208
 
197
209
 
198
- def test_runs_get_params_dict(runs: RunCollection):
199
- params = runs.get_param_dict()
210
+ def test_runs_get_params_dict(rc: RunCollection):
211
+ params = rc.get_param_dict()
200
212
  assert params["p"] == ["0", "1", "2", "3", "4", "5"]
201
213
  assert params["q"] == ["0", "None"]
202
214
  assert params["r"] == ["0", "1", "2"]
203
215
 
204
216
 
205
- def test_runs_find(runs: RunCollection):
206
- run = runs.find({"r": 0})
217
+ def test_runs_find(rc: RunCollection):
218
+ run = rc.find({"r": 0})
207
219
  assert isinstance(run, Run)
208
220
  assert run.data.params["p"] == "0"
209
- run = runs.find(r=2)
221
+ run = rc.find(r=2)
210
222
  assert isinstance(run, Run)
211
223
  assert run.data.params["p"] == "2"
212
224
 
213
225
 
214
- def test_runs_find_none(runs: RunCollection):
226
+ def test_runs_find_none(rc: RunCollection):
215
227
  with pytest.raises(ValueError):
216
- runs.find({"r": 10})
228
+ rc.find({"r": 10})
217
229
 
218
230
 
219
- def test_runs_try_find_none(runs: RunCollection):
220
- run = runs.try_find({"r": 10})
231
+ def test_runs_try_find_none(rc: RunCollection):
232
+ run = rc.try_find({"r": 10})
221
233
  assert run is None
222
234
 
223
235
 
224
- def test_runs_find_last(runs: RunCollection):
225
- run = runs.find_last({"r": 0})
236
+ def test_runs_find_last(rc: RunCollection):
237
+ run = rc.find_last({"r": 0})
226
238
  assert isinstance(run, Run)
227
239
  assert run.data.params["p"] == "3"
228
- run = runs.find_last(r=2)
240
+ run = rc.find_last(r=2)
229
241
  assert isinstance(run, Run)
230
242
  assert run.data.params["p"] == "5"
231
243
 
232
244
 
233
- def test_runs_find_last_none(runs: RunCollection):
245
+ def test_runs_find_last_none(rc: RunCollection):
234
246
  with pytest.raises(ValueError):
235
- runs.find_last({"p": 10})
247
+ rc.find_last({"p": 10})
236
248
 
237
249
 
238
- def test_runs_try_find_last_none(runs: RunCollection):
239
- run = runs.try_find_last({"p": 10})
250
+ def test_runs_try_find_last_none(rc: RunCollection):
251
+ run = rc.try_find_last({"p": 10})
240
252
  assert run is None
241
253
 
242
254
 
@@ -248,7 +260,7 @@ def runs2(monkeypatch, tmp_path):
248
260
  mlflow.log_param("x", x)
249
261
 
250
262
 
251
- def test_list_runs(runs, runs2):
263
+ def test_list_runs(rc, runs2):
252
264
  from hydraflow.mlflow import list_runs
253
265
 
254
266
  mlflow.set_experiment("test_run")
@@ -260,7 +272,7 @@ def test_list_runs(runs, runs2):
260
272
  assert len(all_runs) == 3
261
273
 
262
274
 
263
- def test_list_runs_empty_list(runs, runs2):
275
+ def test_list_runs_empty_list(rc, runs2):
264
276
  from hydraflow.mlflow import list_runs
265
277
 
266
278
  all_runs = list_runs([])
@@ -268,117 +280,118 @@ def test_list_runs_empty_list(runs, runs2):
268
280
 
269
281
 
270
282
  @pytest.mark.parametrize(["name", "n"], [("test_run", 6), ("test_run2", 3)])
271
- def test_list_runs_list(runs, runs2, name, n):
283
+ def test_list_runs_list(rc, runs2, name, n):
272
284
  from hydraflow.mlflow import list_runs
273
285
 
274
286
  filtered_runs = list_runs(name)
275
287
  assert len(filtered_runs) == n
276
288
 
277
289
 
278
- def test_list_runs_none(runs, runs2):
290
+ def test_list_runs_none(rc, runs2):
279
291
  from hydraflow.mlflow import list_runs
280
292
 
281
293
  no_runs = list_runs(["non_existent_experiment"])
282
294
  assert len(no_runs) == 0
295
+ assert not no_runs
283
296
 
284
297
 
285
- def test_run_collection_map(runs: RunCollection):
286
- results = list(runs.map(lambda run: run.info.run_id))
287
- assert len(results) == len(runs._runs)
298
+ def test_run_collection_map(rc: RunCollection):
299
+ results = list(rc.map(lambda run: run.info.run_id))
300
+ assert len(results) == len(rc._runs)
288
301
  assert all(isinstance(run_id, str) for run_id in results)
289
302
 
290
303
 
291
- def test_run_collection_map_args(runs: RunCollection):
292
- results = list(runs.map(lambda run, x: run.info.run_id + x, "test"))
304
+ def test_run_collection_map_args(rc: RunCollection):
305
+ results = list(rc.map(lambda run, x: run.info.run_id + x, "test"))
293
306
  assert all(x.endswith("test") for x in results)
294
307
 
295
308
 
296
- def test_run_collection_map_run_id(runs: RunCollection):
297
- results = list(runs.map_run_id(lambda run_id: run_id))
298
- assert len(results) == len(runs._runs)
309
+ def test_run_collection_map_run_id(rc: RunCollection):
310
+ results = list(rc.map_run_id(lambda run_id: run_id))
311
+ assert len(results) == len(rc._runs)
299
312
  assert all(isinstance(run_id, str) for run_id in results)
300
313
 
301
314
 
302
- def test_run_collection_map_run_id_kwargs(runs: RunCollection):
303
- results = list(runs.map_run_id(lambda run_id, x: x + run_id, x="test"))
315
+ def test_run_collection_map_run_id_kwargs(rc: RunCollection):
316
+ results = list(rc.map_run_id(lambda run_id, x: x + run_id, x="test"))
304
317
  assert all(x.startswith("test") for x in results)
305
318
 
306
319
 
307
- def test_run_collection_map_uri(runs: RunCollection):
308
- results = list(runs.map_uri(lambda uri: uri))
309
- assert len(results) == len(runs._runs)
320
+ def test_run_collection_map_uri(rc: RunCollection):
321
+ results = list(rc.map_uri(lambda uri: uri))
322
+ assert len(results) == len(rc._runs)
310
323
  assert all(isinstance(uri, str | type(None)) for uri in results)
311
324
 
312
325
 
313
- def test_run_collection_map_dir(runs: RunCollection):
314
- results = list(runs.map_dir(lambda dir_path, x: dir_path / x, "a.csv"))
315
- assert len(results) == len(runs._runs)
326
+ def test_run_collection_map_dir(rc: RunCollection):
327
+ results = list(rc.map_dir(lambda dir_path, x: dir_path / x, "a.csv"))
328
+ assert len(results) == len(rc._runs)
316
329
  assert all(isinstance(dir_path, Path) for dir_path in results)
317
330
  assert all(dir_path.stem == "a" for dir_path in results)
318
331
 
319
332
 
320
- def test_run_collection_sort(runs: RunCollection):
321
- runs.sort(key=lambda x: x.data.params["p"])
322
- assert [run.data.params["p"] for run in runs] == ["0", "1", "2", "3", "4", "5"]
333
+ def test_run_collection_sort(rc: RunCollection):
334
+ rc.sort(key=lambda x: x.data.params["p"])
335
+ assert [run.data.params["p"] for run in rc] == ["0", "1", "2", "3", "4", "5"]
323
336
 
324
- runs.sort(reverse=True)
325
- assert [run.data.params["p"] for run in runs] == ["5", "4", "3", "2", "1", "0"]
337
+ rc.sort(reverse=True)
338
+ assert [run.data.params["p"] for run in rc] == ["5", "4", "3", "2", "1", "0"]
326
339
 
327
340
 
328
- def test_run_collection_iter(runs: RunCollection):
329
- assert list(runs) == runs._runs
341
+ def test_run_collection_iter(rc: RunCollection):
342
+ assert list(rc) == rc._runs
330
343
 
331
344
 
332
345
  @pytest.mark.parametrize("i", range(6))
333
- def test_run_collection_getitem(runs: RunCollection, i: int):
334
- assert runs[i] == runs._runs[i]
346
+ def test_run_collection_getitem(rc: RunCollection, i: int):
347
+ assert rc[i] == rc._runs[i]
335
348
 
336
349
 
337
350
  @pytest.mark.parametrize("i", range(6))
338
- def test_run_collection_getitem_slice(runs: RunCollection, i: int):
339
- assert runs[i : i + 2]._runs == runs._runs[i : i + 2]
351
+ def test_run_collection_getitem_slice(rc: RunCollection, i: int):
352
+ assert rc[i : i + 2]._runs == rc._runs[i : i + 2]
340
353
 
341
354
 
342
355
  @pytest.mark.parametrize("i", range(6))
343
- def test_run_collection_getitem_slice_step(runs: RunCollection, i: int):
344
- assert runs[i::2]._runs == runs._runs[i::2]
356
+ def test_run_collection_getitem_slice_step(rc: RunCollection, i: int):
357
+ assert rc[i::2]._runs == rc._runs[i::2]
345
358
 
346
359
 
347
360
  @pytest.mark.parametrize("i", range(6))
348
- def test_run_collection_getitem_slice_step_neg(runs: RunCollection, i: int):
349
- assert runs[i::-2]._runs == runs._runs[i::-2]
361
+ def test_run_collection_getitem_slice_step_neg(rc: RunCollection, i: int):
362
+ assert rc[i::-2]._runs == rc._runs[i::-2]
350
363
 
351
364
 
352
- def test_run_collection_take(runs: RunCollection):
353
- assert runs.take(3)._runs == runs._runs[:3]
354
- assert len(runs.take(4)) == 4
355
- assert runs.take(10)._runs == runs._runs
365
+ def test_run_collection_take(rc: RunCollection):
366
+ assert rc.take(3)._runs == rc._runs[:3]
367
+ assert len(rc.take(4)) == 4
368
+ assert rc.take(10)._runs == rc._runs
356
369
 
357
370
 
358
- def test_run_collection_take_neg(runs: RunCollection):
359
- assert runs.take(-3)._runs == runs._runs[-3:]
360
- assert len(runs.take(-4)) == 4
361
- assert runs.take(-10)._runs == runs._runs
371
+ def test_run_collection_take_neg(rc: RunCollection):
372
+ assert rc.take(-3)._runs == rc._runs[-3:]
373
+ assert len(rc.take(-4)) == 4
374
+ assert rc.take(-10)._runs == rc._runs
362
375
 
363
376
 
364
377
  @pytest.mark.parametrize("i", range(6))
365
- def test_run_collection_contains(runs: RunCollection, i: int):
366
- assert runs[i] in runs
367
- assert runs._runs[i] in runs
378
+ def test_run_collection_contains(rc: RunCollection, i: int):
379
+ assert rc[i] in rc
380
+ assert rc._runs[i] in rc
368
381
 
369
382
 
370
- def test_run_collection_group_by(runs: RunCollection):
371
- grouped = runs.group_by(["p"])
383
+ def test_run_collection_group_by(rc: RunCollection):
384
+ grouped = rc.group_by(["p"])
372
385
  assert len(grouped) == 6
373
386
  assert all(isinstance(group, RunCollection) for group in grouped.values())
374
387
  assert all(len(group) == 1 for group in grouped.values())
375
- assert grouped[("0",)][0] == runs[0]
376
- assert grouped[("1",)][0] == runs[1]
388
+ assert grouped[("0",)][0] == rc[0]
389
+ assert grouped[("1",)][0] == rc[1]
377
390
 
378
- grouped = runs.group_by("q")
391
+ grouped = rc.group_by("q")
379
392
  assert len(grouped) == 2
380
393
 
381
- grouped = runs.group_by("r")
394
+ grouped = rc.group_by("r")
382
395
  assert len(grouped) == 3
383
396
 
384
397
 
@@ -396,24 +409,24 @@ def test_filter_runs_no_match(run_list: list[Run]):
396
409
  assert x == []
397
410
 
398
411
 
399
- def test_get_run_no_match(runs: RunCollection):
412
+ def test_get_run_no_match(rc: RunCollection):
400
413
  with pytest.raises(ValueError):
401
- runs.get({"p": 10})
414
+ rc.get({"p": 10})
402
415
 
403
416
 
404
- def test_get_run_multiple_params(runs: RunCollection):
405
- run = runs.get({"p": 4, "q": 0})
417
+ def test_get_run_multiple_params(rc: RunCollection):
418
+ run = rc.get({"p": 4, "q": 0})
406
419
  assert isinstance(run, Run)
407
420
  assert run.data.params["p"] == "4"
408
421
  assert run.data.params["q"] == "0"
409
422
 
410
423
 
411
- def test_try_get_run_no_match(runs: RunCollection):
412
- assert runs.try_get({"p": 10}) is None
424
+ def test_try_get_run_no_match(rc: RunCollection):
425
+ assert rc.try_get({"p": 10}) is None
413
426
 
414
427
 
415
- def test_try_get_run_multiple_params(runs: RunCollection):
416
- run = runs.try_get({"p": 4, "q": 0})
428
+ def test_try_get_run_multiple_params(rc: RunCollection):
429
+ run = rc.try_get({"p": 4, "q": 0})
417
430
  assert isinstance(run, Run)
418
431
  assert run.data.params["p"] == "4"
419
432
  assert run.data.params["q"] == "0"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes