hydraflow 0.4.5__tar.gz → 0.5.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. {hydraflow-0.4.5 → hydraflow-0.5.0}/.devcontainer/devcontainer.json +1 -0
  2. {hydraflow-0.4.5 → hydraflow-0.5.0}/.devcontainer/postCreate.sh +4 -3
  3. {hydraflow-0.4.5 → hydraflow-0.5.0}/PKG-INFO +3 -17
  4. {hydraflow-0.4.5 → hydraflow-0.5.0}/README.md +0 -10
  5. {hydraflow-0.4.5 → hydraflow-0.5.0}/pyproject.toml +9 -14
  6. {hydraflow-0.4.5 → hydraflow-0.5.0}/src/hydraflow/__init__.py +3 -5
  7. {hydraflow-0.4.5 → hydraflow-0.5.0}/src/hydraflow/context.py +1 -106
  8. {hydraflow-0.4.5 → hydraflow-0.5.0}/src/hydraflow/run_collection.py +1 -1
  9. {hydraflow-0.4.5 → hydraflow-0.5.0}/src/hydraflow/run_data.py +3 -3
  10. {hydraflow-0.4.5 → hydraflow-0.5.0}/src/hydraflow/utils.py +16 -6
  11. {hydraflow-0.4.5/tests/scripts → hydraflow-0.5.0/tests/apps}/app.py +4 -14
  12. hydraflow-0.5.0/tests/conftest.py +16 -0
  13. {hydraflow-0.4.5 → hydraflow-0.5.0}/tests/test_app.py +17 -19
  14. {hydraflow-0.4.5 → hydraflow-0.5.0}/tests/test_context.py +2 -22
  15. {hydraflow-0.4.5 → hydraflow-0.5.0}/tests/test_log_run.py +13 -11
  16. {hydraflow-0.4.5 → hydraflow-0.5.0}/tests/test_mlflow.py +1 -1
  17. {hydraflow-0.4.5 → hydraflow-0.5.0}/tests/test_param.py +3 -1
  18. hydraflow-0.5.0/tests/test_run_collection.py +465 -0
  19. {hydraflow-0.4.5 → hydraflow-0.5.0}/tests/test_run_data.py +14 -17
  20. hydraflow-0.5.0/tests/test_run_info.py +42 -0
  21. hydraflow-0.5.0/tests/test_utils.py +32 -0
  22. hydraflow-0.4.5/src/hydraflow/asyncio.py +0 -227
  23. hydraflow-0.4.5/src/hydraflow/progress.py +0 -184
  24. hydraflow-0.4.5/tests/integ/app.py +0 -62
  25. hydraflow-0.4.5/tests/scripts/__init__.py +0 -0
  26. hydraflow-0.4.5/tests/scripts/progress.py +0 -65
  27. hydraflow-0.4.5/tests/scripts/watch.py +0 -9
  28. hydraflow-0.4.5/tests/test_asyncio.py +0 -221
  29. hydraflow-0.4.5/tests/test_progress.py +0 -12
  30. hydraflow-0.4.5/tests/test_run_collection.py +0 -487
  31. hydraflow-0.4.5/tests/test_run_info.py +0 -48
  32. hydraflow-0.4.5/tests/test_version.py +0 -5
  33. hydraflow-0.4.5/tests/test_watch.py +0 -28
  34. {hydraflow-0.4.5 → hydraflow-0.5.0}/.devcontainer/starship.toml +0 -0
  35. {hydraflow-0.4.5 → hydraflow-0.5.0}/.gitattributes +0 -0
  36. {hydraflow-0.4.5 → hydraflow-0.5.0}/.gitignore +0 -0
  37. {hydraflow-0.4.5 → hydraflow-0.5.0}/LICENSE +0 -0
  38. {hydraflow-0.4.5 → hydraflow-0.5.0}/apps/quickstart.py +0 -0
  39. {hydraflow-0.4.5 → hydraflow-0.5.0}/mkdocs.yml +0 -0
  40. {hydraflow-0.4.5 → hydraflow-0.5.0}/src/hydraflow/config.py +0 -0
  41. {hydraflow-0.4.5 → hydraflow-0.5.0}/src/hydraflow/mlflow.py +0 -0
  42. {hydraflow-0.4.5 → hydraflow-0.5.0}/src/hydraflow/param.py +0 -0
  43. {hydraflow-0.4.5 → hydraflow-0.5.0}/src/hydraflow/py.typed +0 -0
  44. {hydraflow-0.4.5 → hydraflow-0.5.0}/src/hydraflow/run_info.py +0 -0
  45. {hydraflow-0.4.5 → hydraflow-0.5.0}/tests/__init__.py +0 -0
  46. {hydraflow-0.4.5/tests/integ → hydraflow-0.5.0/tests/apps}/__init__.py +0 -0
  47. {hydraflow-0.4.5 → hydraflow-0.5.0}/tests/test_config.py +0 -0
@@ -8,6 +8,7 @@
8
8
  "extensions": [
9
9
  "charliermarsh.ruff",
10
10
  "fill-labs.dependi",
11
+ "markis.code-coverage",
11
12
  "ms-python.python",
12
13
  "ms-python.vscode-pylance",
13
14
  "tamasfe.even-better-toml"
@@ -1,10 +1,11 @@
1
1
  #!/bin/bash
2
2
 
3
3
  echo 'eval "$(starship init bash)"' >> ~/.bashrc
4
- echo "alias ll='ls -alF'" >> ~/.bashrc
5
4
  mkdir -p ~/.config
6
5
  cp .devcontainer/starship.toml ~/.config
7
6
 
8
7
  curl -LsSf https://astral.sh/uv/install.sh | sh
9
- source $HOME/.cargo/env
10
- echo 'eval "$(uv generate-shell-completion bash)"' >> ~/.bashrc
8
+ echo 'eval "$(uv generate-shell-completion bash)"' >> ~/.bashrc
9
+ uv tool install ruff@latest
10
+ uv python install
11
+ uv sync -U
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: hydraflow
3
- Version: 0.4.5
3
+ Version: 0.5.0
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -27,6 +27,7 @@ License: MIT License
27
27
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28
28
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29
29
  SOFTWARE.
30
+ License-File: LICENSE
30
31
  Classifier: Development Status :: 4 - Beta
31
32
  Classifier: License :: OSI Approved :: MIT License
32
33
  Classifier: Programming Language :: Python
@@ -36,12 +37,7 @@ Classifier: Programming Language :: Python :: 3.12
36
37
  Classifier: Programming Language :: Python :: 3.13
37
38
  Requires-Python: >=3.10
38
39
  Requires-Dist: hydra-core>=1.3
39
- Requires-Dist: joblib
40
40
  Requires-Dist: mlflow>=2.15
41
- Requires-Dist: polars
42
- Requires-Dist: rich
43
- Requires-Dist: watchdog
44
- Requires-Dist: watchfiles
45
41
  Description-Content-Type: text/markdown
46
42
 
47
43
  # Hydraflow
@@ -118,16 +114,6 @@ def my_app(cfg: MySQLConfig) -> None:
118
114
  with hydraflow.start_run():
119
115
  # Your app code below.
120
116
 
121
- with hydraflow.watch(callback):
122
- # Watch files in the MLflow artifact directory.
123
- # You can update metrics or log other artifacts
124
- # according to the watched files in your callback
125
- # function.
126
- pass
127
-
128
- # Your callback function here.
129
- def callback(file: Path) -> None:
130
- pass
131
117
 
132
118
  if __name__ == "__main__":
133
119
  my_app()
@@ -72,16 +72,6 @@ def my_app(cfg: MySQLConfig) -> None:
72
72
  with hydraflow.start_run():
73
73
  # Your app code below.
74
74
 
75
- with hydraflow.watch(callback):
76
- # Watch files in the MLflow artifact directory.
77
- # You can update metrics or log other artifacts
78
- # according to the watched files in your callback
79
- # function.
80
- pass
81
-
82
- # Your callback function here.
83
- def callback(file: Path) -> None:
84
- pass
85
75
 
86
76
  if __name__ == "__main__":
87
77
  my_app()
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.4.5"
7
+ version = "0.5.0"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -19,15 +19,7 @@ classifiers = [
19
19
  "Programming Language :: Python :: 3.13",
20
20
  ]
21
21
  requires-python = ">=3.10"
22
- dependencies = [
23
- "hydra-core>=1.3",
24
- "joblib",
25
- "mlflow>=2.15",
26
- "polars",
27
- "rich",
28
- "watchdog",
29
- "watchfiles",
30
- ]
22
+ dependencies = ["hydra-core>=1.3", "mlflow>=2.15"]
31
23
 
32
24
  [project.urls]
33
25
  Documentation = "https://github.com/daizutabi/hydraflow"
@@ -40,12 +32,10 @@ dev-dependencies = [
40
32
  "mkapi",
41
33
  "mkdocs-material",
42
34
  "mkdocs>=1.6",
43
- "pytest-asyncio",
44
35
  "pytest-clarity",
45
36
  "pytest-cov",
46
37
  "pytest-randomly",
47
38
  "pytest-xdist",
48
- "ruff",
49
39
  ]
50
40
 
51
41
  [tool.hatch.build.targets.sdist]
@@ -61,11 +51,15 @@ addopts = [
61
51
  "--cov-report=lcov:lcov.info",
62
52
  ]
63
53
  doctest_optionflags = ["NORMALIZE_WHITESPACE", "IGNORE_EXCEPTION_DETAIL"]
64
- filterwarnings = ['ignore:pkg_resources is deprecated:DeprecationWarning']
65
- asyncio_default_fixture_loop_scope = "function"
54
+ filterwarnings = [
55
+ "ignore:pkg_resources is deprecated:DeprecationWarning",
56
+ "ignore:Support for class-based `config` is deprecated",
57
+ "ignore:Pydantic V1 style",
58
+ ]
66
59
 
67
60
  [tool.coverage.report]
68
61
  exclude_lines = ["no cov", "raise NotImplementedError", "if TYPE_CHECKING:"]
62
+ skip_covered = true
69
63
 
70
64
  [tool.ruff]
71
65
  line-length = 88
@@ -75,6 +69,7 @@ target-version = "py310"
75
69
  select = ["ALL"]
76
70
  unfixable = ["F401"]
77
71
  ignore = [
72
+ "A005",
78
73
  "ANN003",
79
74
  "ANN401",
80
75
  "ARG002",
@@ -1,9 +1,8 @@
1
1
  """Integrate Hydra and MLflow to manage and track machine learning experiments."""
2
2
 
3
3
  from .config import select_config, select_overrides
4
- from .context import chdir_artifact, chdir_hydra_output, log_run, start_run, watch
4
+ from .context import chdir_artifact, chdir_hydra_output, log_run, start_run
5
5
  from .mlflow import list_runs, search_runs, set_experiment
6
- from .progress import multi_tasks_progress, parallel_progress
7
6
  from .run_collection import RunCollection
8
7
  from .utils import (
9
8
  get_artifact_dir,
@@ -11,6 +10,7 @@ from .utils import (
11
10
  get_overrides,
12
11
  load_config,
13
12
  load_overrides,
13
+ remove_run,
14
14
  )
15
15
 
16
16
  __all__ = [
@@ -24,12 +24,10 @@ __all__ = [
24
24
  "load_config",
25
25
  "load_overrides",
26
26
  "log_run",
27
- "multi_tasks_progress",
28
- "parallel_progress",
27
+ "remove_run",
29
28
  "search_runs",
30
29
  "select_config",
31
30
  "select_overrides",
32
31
  "set_experiment",
33
32
  "start_run",
34
- "watch",
35
33
  ]
@@ -4,7 +4,6 @@ from __future__ import annotations
4
4
 
5
5
  import logging
6
6
  import os
7
- import time
8
7
  from contextlib import contextmanager
9
8
  from pathlib import Path
10
9
  from typing import TYPE_CHECKING
@@ -12,14 +11,11 @@ from typing import TYPE_CHECKING
12
11
  import mlflow
13
12
  import mlflow.artifacts
14
13
  from hydra.core.hydra_config import HydraConfig
15
- from watchdog.events import FileModifiedEvent, PatternMatchingEventHandler
16
- from watchdog.observers import Observer
17
14
 
18
15
  from hydraflow.mlflow import log_params
19
- from hydraflow.run_info import get_artifact_dir
20
16
 
21
17
  if TYPE_CHECKING:
22
- from collections.abc import Callable, Iterator
18
+ from collections.abc import Iterator
23
19
 
24
20
  from mlflow.entities.run import Run
25
21
 
@@ -64,14 +60,8 @@ def log_run(
64
60
  output_subdir = output_dir / (hc.output_subdir or "")
65
61
  mlflow.log_artifacts(output_subdir.as_posix(), hc.output_subdir)
66
62
 
67
- def log_artifact(path: Path) -> None:
68
- local_path = (output_dir / path).as_posix()
69
- mlflow.log_artifact(local_path)
70
-
71
63
  try:
72
64
  yield
73
- # with watch(log_artifact, output_dir, ignore_log=False):
74
- # yield
75
65
 
76
66
  except Exception as e:
77
67
  msg = f"Error during log_run: {e}"
@@ -146,101 +136,6 @@ def start_run( # noqa: PLR0913
146
136
  yield run
147
137
 
148
138
 
149
- @contextmanager
150
- def watch(
151
- callback: Callable[[Path], None],
152
- dir: Path | str = "", # noqa: A002
153
- *,
154
- timeout: int = 60,
155
- ignore_patterns: list[str] | None = None,
156
- ignore_log: bool = True,
157
- ) -> Iterator[None]:
158
- """Watch the given directory for changes.
159
-
160
- This context manager sets up a file system watcher on the specified directory.
161
- When a file modification is detected, the provided function is called with
162
- the path of the modified file. The watcher runs for the specified timeout
163
- period or until the context is exited.
164
-
165
- Args:
166
- callback (Callable[[Path], None]): The function to call when a change is
167
- detected. It should accept a single argument of type `Path`,
168
- which is the path of the modified file.
169
- dir (Path | str): The directory to watch. If not specified,
170
- the current MLflow artifact URI is used. Defaults to "".
171
- timeout (int): The timeout period in seconds for the watcher
172
- to run after the context is exited. Defaults to 60.
173
- ignore_patterns (list[str] | None): A list of glob patterns to ignore.
174
- Defaults to None.
175
- ignore_log (bool): Whether to ignore log files. Defaults to True.
176
-
177
- Yields:
178
- None
179
-
180
- Example:
181
- ```python
182
- with watch(log_artifact, "/path/to/dir"):
183
- # Perform operations while watching the directory for changes
184
- pass
185
- ```
186
-
187
- """
188
- dir = dir or get_artifact_dir() # noqa: A001
189
- if isinstance(dir, Path):
190
- dir = dir.as_posix() # noqa: A001
191
-
192
- handler = Handler(callback, ignore_patterns=ignore_patterns, ignore_log=ignore_log)
193
- observer = Observer()
194
- observer.schedule(handler, dir, recursive=True)
195
- observer.start()
196
-
197
- try:
198
- yield
199
-
200
- except Exception as e:
201
- msg = f"Error during watch: {e}"
202
- log.exception(msg)
203
- raise
204
-
205
- finally:
206
- elapsed = 0
207
- while not observer.event_queue.empty():
208
- time.sleep(0.2)
209
- elapsed += 0.2
210
- if elapsed > timeout:
211
- break
212
-
213
- observer.stop()
214
- observer.join()
215
-
216
-
217
- class Handler(PatternMatchingEventHandler):
218
- """Monitor file changes and call the given function when a change is detected."""
219
-
220
- def __init__(
221
- self,
222
- func: Callable[[Path], None],
223
- *,
224
- ignore_patterns: list[str] | None = None,
225
- ignore_log: bool = True,
226
- ) -> None:
227
- self.func = func
228
-
229
- if ignore_log:
230
- if ignore_patterns:
231
- ignore_patterns.append("*.log")
232
- else:
233
- ignore_patterns = ["*.log"]
234
-
235
- super().__init__(ignore_patterns=ignore_patterns)
236
-
237
- def on_modified(self, event: FileModifiedEvent) -> None:
238
- """Modify when a file is modified."""
239
- file = Path(str(event.src_path))
240
- if file.is_file():
241
- self.func(file)
242
-
243
-
244
139
  @contextmanager
245
140
  def chdir_hydra_output() -> Iterator[Path]:
246
141
  """Change the current working directory to the hydra output directory.
@@ -575,7 +575,7 @@ class RunCollection:
575
575
  """
576
576
  return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir) # noqa: A001
577
577
 
578
- def group_by(
578
+ def groupby(
579
579
  self,
580
580
  names: str | list[str],
581
581
  ) -> dict[str | None | tuple[str | None, ...], RunCollection]:
@@ -4,7 +4,7 @@ from __future__ import annotations
4
4
 
5
5
  from typing import TYPE_CHECKING
6
6
 
7
- from polars.dataframe import DataFrame
7
+ from pandas import DataFrame
8
8
 
9
9
  from hydraflow.config import collect_params
10
10
 
@@ -33,10 +33,10 @@ class RunCollectionData:
33
33
 
34
34
  @property
35
35
  def config(self) -> DataFrame:
36
- """Get the runs' configurations as a polars DataFrame.
36
+ """Get the runs' configurations as a DataFrame.
37
37
 
38
38
  Returns:
39
- A polars DataFrame containing the runs' configurations.
39
+ A DataFrame containing the runs' configurations.
40
40
 
41
41
  """
42
42
  return DataFrame(self._runs.map_config(collect_params))
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import shutil
5
6
  from pathlib import Path
6
7
  from typing import TYPE_CHECKING
7
8
 
@@ -9,11 +10,10 @@ import mlflow
9
10
  import mlflow.artifacts
10
11
  from hydra.core.hydra_config import HydraConfig
11
12
  from mlflow.entities import Run
12
- from mlflow.tracking import artifact_utils
13
13
  from omegaconf import DictConfig, OmegaConf
14
14
 
15
15
  if TYPE_CHECKING:
16
- from mlflow.entities import Run
16
+ from collections.abc import Iterable
17
17
 
18
18
 
19
19
  def get_artifact_dir(run: Run | None = None) -> Path:
@@ -28,10 +28,10 @@ def get_artifact_dir(run: Run | None = None) -> Path:
28
28
  The local path to the directory where the artifacts are downloaded.
29
29
 
30
30
  """
31
- if run is None:
32
- uri = mlflow.get_artifact_uri()
33
- else:
34
- uri = artifact_utils.get_artifact_uri(run.info.run_id)
31
+ uri = mlflow.get_artifact_uri() if run is None else run.info.artifact_uri
32
+
33
+ if not (isinstance(uri, str) and uri.startswith("file://")):
34
+ raise NotImplementedError
35
35
 
36
36
  return Path(mlflow.artifacts.download_artifacts(uri))
37
37
 
@@ -112,3 +112,13 @@ def load_overrides(run: Run) -> list[str]:
112
112
  """
113
113
  path = get_artifact_dir(run) / ".hydra/overrides.yaml"
114
114
  return [str(x) for x in OmegaConf.load(path)]
115
+
116
+
117
+ def remove_run(run: Run | Iterable[Run]) -> None:
118
+ """Remove the given run from the MLflow tracking server."""
119
+ if not isinstance(run, Run):
120
+ for r in run:
121
+ remove_run(r)
122
+ return
123
+
124
+ shutil.rmtree(get_artifact_dir(run).parent)
@@ -1,7 +1,6 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import logging
4
- import time
5
4
  from dataclasses import dataclass, field
6
5
  from pathlib import Path
7
6
 
@@ -41,8 +40,9 @@ def app(cfg: MySQLConfig):
41
40
  assert cfg.get("values") == [1, 2, 3] # type: ignore
42
41
 
43
42
  hydraflow.set_experiment(prefix="_", suffix="_")
44
- with hydraflow.start_run(cfg) as run:
45
- log.info(f"START, {cfg.host}, {cfg.port} ")
43
+ with hydraflow.start_run(cfg, synchronous=True) as run:
44
+ msg = f"START, {cfg.host}, {cfg.port} "
45
+ log.info(msg)
46
46
 
47
47
  artifact_dir = hydraflow.get_artifact_dir()
48
48
  output_dir = hydraflow.get_hydra_output_dir()
@@ -52,10 +52,6 @@ def app(cfg: MySQLConfig):
52
52
  mlflow.log_text("A " + artifact_dir.as_posix(), "artifact_dir.txt")
53
53
  mlflow.log_text("B " + output_dir.as_posix(), "output_dir.txt")
54
54
 
55
- # with hydraflow.watch(callback, ignore_patterns=["b.txt"]):
56
- # (artifact_dir / "a.txt").write_text("abc")
57
- # time.sleep(0.1)
58
-
59
55
  (artifact_dir / "a.txt").write_text("abc")
60
56
 
61
57
  mlflow.log_metric("m", cfg.port + 1, 1)
@@ -65,16 +61,10 @@ def app(cfg: MySQLConfig):
65
61
  assert hydraflow.get_overrides() == hydraflow.load_overrides(run)
66
62
 
67
63
  if cfg.host == "error":
68
- raise Exception("error")
64
+ raise Exception("error") # noqa: TRY002
69
65
 
70
66
  log.info("END")
71
67
 
72
68
 
73
- def callback(path: Path):
74
- log.info(f"WATCH, {path.as_posix()}")
75
- m = len(path.read_text()) # len("abc") == 3
76
- # mlflow.log_metric("watch", m, 1, synchronous=True)
77
-
78
-
79
69
  if __name__ == "__main__":
80
70
  app()
@@ -0,0 +1,16 @@
1
+ import os
2
+ import uuid
3
+ from pathlib import Path
4
+
5
+ import mlflow
6
+ import pytest
7
+
8
+
9
+ @pytest.fixture(scope="module")
10
+ def experiment_name(tmp_path_factory: pytest.TempPathFactory):
11
+ cwd = Path.cwd()
12
+ name = str(uuid.uuid4())
13
+ os.chdir(tmp_path_factory.mktemp(name))
14
+ mlflow.set_experiment(name)
15
+ yield name
16
+ os.chdir(cwd)
@@ -1,27 +1,24 @@
1
- from __future__ import annotations
2
-
1
+ import os
3
2
  import subprocess
4
3
  import sys
5
4
  from pathlib import Path
6
- from typing import TYPE_CHECKING
7
5
 
8
6
  import mlflow
9
7
  import pytest
10
8
  from mlflow.entities import RunStatus
11
- from omegaconf import OmegaConf
12
-
13
- if TYPE_CHECKING:
14
- from omegaconf import DictConfig
9
+ from omegaconf import DictConfig, OmegaConf
15
10
 
16
- from hydraflow.run_collection import RunCollection
11
+ from hydraflow.run_collection import RunCollection
17
12
 
18
13
 
19
- @pytest.fixture
20
- def rc(monkeypatch, tmp_path):
14
+ @pytest.fixture(scope="module")
15
+ def rc(tmp_path_factory: pytest.TempPathFactory):
21
16
  import hydraflow
22
17
 
23
- file = Path("tests/scripts/app.py").absolute()
24
- monkeypatch.chdir(tmp_path)
18
+ cwd = Path.cwd()
19
+
20
+ file = Path("tests/apps/app.py").absolute()
21
+ os.chdir(tmp_path_factory.mktemp("test_app"))
25
22
 
26
23
  args = [sys.executable, file.as_posix(), "-m"]
27
24
  args += ["host=x,y", "port=1,2", "hydra.job.name=info"]
@@ -30,6 +27,8 @@ def rc(monkeypatch, tmp_path):
30
27
  mlflow.set_experiment("_info_")
31
28
  yield hydraflow.list_runs()
32
29
 
30
+ os.chdir(cwd)
31
+
33
32
 
34
33
  def test_list_runs_all(rc: RunCollection):
35
34
  from hydraflow.mlflow import list_runs
@@ -112,7 +111,7 @@ def test_app_data_config(rc: RunCollection):
112
111
  def test_app_data_config_list(rc: RunCollection):
113
112
  config = rc.data.config
114
113
  values = config["values"].to_list()
115
- assert str(config.select("values").dtypes) == "[List(Int64)]"
114
+ assert str(config["values"].dtypes) == "object"
116
115
  for x in values:
117
116
  assert isinstance(x, list)
118
117
  assert x == [1, 2, 3]
@@ -154,8 +153,8 @@ def test_app_map_config(rc: RunCollection):
154
153
  assert ports == [2, 3, 2, 3]
155
154
 
156
155
 
157
- def test_app_group_by(rc: RunCollection):
158
- grouped = rc.group_by("host")
156
+ def test_app_groupby(rc: RunCollection):
157
+ grouped = rc.groupby("host")
159
158
  assert len(grouped) == 2
160
159
  assert grouped["x"].data.params["port"] == ["1", "2"]
161
160
  assert grouped["x"].data.params["host"] == ["x", "x"]
@@ -165,8 +164,8 @@ def test_app_group_by(rc: RunCollection):
165
164
  assert grouped["y"].data.params["values"] == ["[1, 2, 3]", "[1, 2, 3]"]
166
165
 
167
166
 
168
- def test_app_group_by_list(rc: RunCollection):
169
- grouped = rc.group_by(["host"])
167
+ def test_app_groupby_list(rc: RunCollection):
168
+ grouped = rc.groupby(["host"])
170
169
  assert len(grouped) == 2
171
170
  assert ("x",) in grouped
172
171
  assert ("y",) in grouped
@@ -203,8 +202,7 @@ def test_log_run_error(monkeypatch, tmp_path):
203
202
  args = [sys.executable, file.as_posix()]
204
203
  args += ["host=error", "hydra.job.name=error"]
205
204
  cp = subprocess.run(args, check=False, capture_output=True)
206
- assert cp.returncode == 1
207
- assert b"Error during log_run: error" in cp.stdout
205
+ assert cp.returncode
208
206
 
209
207
 
210
208
  def test_chdir_artifact(rc: RunCollection):
@@ -1,16 +1,15 @@
1
- import time
2
1
  from pathlib import Path
3
2
  from unittest.mock import MagicMock, patch
4
3
 
5
4
  import mlflow
6
5
  import pytest
7
6
 
8
- from hydraflow.context import log_run, start_run, watch
7
+ from hydraflow.context import log_run, start_run
9
8
  from hydraflow.run_collection import RunCollection
10
9
 
11
10
 
12
11
  @pytest.fixture
13
- def runs(monkeypatch, tmp_path):
12
+ def runs(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
14
13
  from hydraflow.mlflow import list_runs
15
14
 
16
15
  monkeypatch.chdir(tmp_path)
@@ -67,22 +66,3 @@ def test_log_run_error_handling(tmp_path: Path):
67
66
  with pytest.raises(Exception, match="Test exception"):
68
67
  with log_run(config):
69
68
  pass
70
-
71
-
72
- def test_watch_context_manager(tmp_path: Path):
73
- test_dir = tmp_path / "test_watch"
74
- test_dir.mkdir(parents=True, exist_ok=True)
75
- test_file = test_dir / "test_file.txt"
76
-
77
- called = []
78
-
79
- def mock_func(path: Path):
80
- assert path == test_file
81
- called.append(path)
82
-
83
- with watch(mock_func, test_dir):
84
- test_file.write_text("new content")
85
- time.sleep(1)
86
-
87
- assert len(called) == 1
88
- assert called[0] == test_file
@@ -1,5 +1,4 @@
1
- from __future__ import annotations
2
-
1
+ import os
3
2
  import subprocess
4
3
  import sys
5
4
  from pathlib import Path
@@ -10,10 +9,12 @@ from mlflow.artifacts import download_artifacts
10
9
  from mlflow.entities.run import Run
11
10
 
12
11
 
13
- @pytest.fixture
14
- def runs(monkeypatch, tmp_path):
15
- file = Path("tests/scripts/app.py").absolute()
16
- monkeypatch.chdir(tmp_path)
12
+ @pytest.fixture(scope="module")
13
+ def runs(tmp_path_factory: pytest.TempPathFactory):
14
+ file = Path("tests/apps/app.py").absolute()
15
+
16
+ cwd = Path.cwd()
17
+ os.chdir(tmp_path_factory.mktemp("test_log_run"))
17
18
 
18
19
  args = [sys.executable, file.as_posix(), "-m"]
19
20
  args += ["host=x,y", "port=1,2", "hydra.job.name=log_run"]
@@ -21,16 +22,17 @@ def runs(monkeypatch, tmp_path):
21
22
 
22
23
  mlflow.set_experiment("_log_run_")
23
24
  runs = mlflow.search_runs(output_format="list")
25
+
24
26
  assert len(runs) == 4
25
27
  assert isinstance(runs, list)
26
28
  yield runs
27
29
 
30
+ os.chdir(cwd)
31
+
28
32
 
29
- @pytest.fixture(params=range(4))
30
- def run(runs, request):
31
- run = runs[request.param] # type: ignore
32
- assert isinstance(run, Run)
33
- return run
33
+ @pytest.fixture(scope="module", params=range(4))
34
+ def run(runs: list[Run], request: pytest.FixtureRequest):
35
+ return runs[request.param]
34
36
 
35
37
 
36
38
  @pytest.fixture
@@ -5,7 +5,7 @@ from hydra.core.hydra_config import HydraConfig
5
5
 
6
6
 
7
7
  @pytest.fixture
8
- def hydra_config(monkeypatch):
8
+ def hydra_config(monkeypatch: pytest.MonkeyPatch):
9
9
  class MockJob:
10
10
  name = "test_job"
11
11
 
@@ -1,9 +1,11 @@
1
+ from pathlib import Path
2
+
1
3
  import mlflow
2
4
  import pytest
3
5
 
4
6
 
5
7
  @pytest.fixture
6
- def param(monkeypatch, tmp_path):
8
+ def param(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
7
9
  def param(value):
8
10
  monkeypatch.chdir(tmp_path)
9
11
  mlflow.set_experiment("test_param")