hydraflow 0.2.5__tar.gz → 0.2.7__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (31) hide show
  1. {hydraflow-0.2.5 → hydraflow-0.2.7}/PKG-INFO +3 -1
  2. hydraflow-0.2.7/mlruns/0/meta.yaml +6 -0
  3. {hydraflow-0.2.5 → hydraflow-0.2.7}/pyproject.toml +3 -1
  4. {hydraflow-0.2.5 → hydraflow-0.2.7}/src/hydraflow/__init__.py +2 -2
  5. hydraflow-0.2.7/src/hydraflow/info.py +63 -0
  6. {hydraflow-0.2.5 → hydraflow-0.2.7}/src/hydraflow/mlflow.py +30 -35
  7. hydraflow-0.2.7/src/hydraflow/progress.py +131 -0
  8. hydraflow-0.2.5/src/hydraflow/runs.py → hydraflow-0.2.7/src/hydraflow/run_collection.py +133 -82
  9. {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_asyncio.py +1 -0
  10. {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_context.py +22 -14
  11. hydraflow-0.2.7/tests/test_info.py +51 -0
  12. {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_log_run.py +1 -1
  13. hydraflow-0.2.7/tests/test_progress.py +12 -0
  14. hydraflow-0.2.5/tests/test_runs.py → hydraflow-0.2.7/tests/test_run_collection.py +92 -40
  15. {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_watch.py +4 -2
  16. {hydraflow-0.2.5 → hydraflow-0.2.7}/.devcontainer/devcontainer.json +0 -0
  17. {hydraflow-0.2.5 → hydraflow-0.2.7}/.devcontainer/postCreate.sh +0 -0
  18. {hydraflow-0.2.5 → hydraflow-0.2.7}/.devcontainer/starship.toml +0 -0
  19. {hydraflow-0.2.5 → hydraflow-0.2.7}/.gitattributes +0 -0
  20. {hydraflow-0.2.5 → hydraflow-0.2.7}/.gitignore +0 -0
  21. {hydraflow-0.2.5 → hydraflow-0.2.7}/LICENSE +0 -0
  22. {hydraflow-0.2.5 → hydraflow-0.2.7}/README.md +0 -0
  23. {hydraflow-0.2.5 → hydraflow-0.2.7}/src/hydraflow/asyncio.py +0 -0
  24. {hydraflow-0.2.5 → hydraflow-0.2.7}/src/hydraflow/config.py +0 -0
  25. {hydraflow-0.2.5 → hydraflow-0.2.7}/src/hydraflow/context.py +0 -0
  26. {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/scripts/__init__.py +0 -0
  27. {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/scripts/log_run.py +0 -0
  28. {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/scripts/watch.py +0 -0
  29. {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_config.py +0 -0
  30. {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_mlflow.py +0 -0
  31. {hydraflow-0.2.5 → hydraflow-0.2.7}/tests/test_version.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.2.5
3
+ Version: 0.2.7
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -17,7 +17,9 @@ Classifier: Topic :: Documentation
17
17
  Classifier: Topic :: Software Development :: Documentation
18
18
  Requires-Python: >=3.10
19
19
  Requires-Dist: hydra-core>1.3
20
+ Requires-Dist: joblib
20
21
  Requires-Dist: mlflow>2.15
22
+ Requires-Dist: rich
21
23
  Requires-Dist: setuptools
22
24
  Requires-Dist: watchdog
23
25
  Requires-Dist: watchfiles
@@ -0,0 +1,6 @@
1
+ artifact_location: file:///workspaces/hydraflow/mlruns/0
2
+ creation_time: 1725536713011
3
+ experiment_id: '0'
4
+ last_update_time: 1725536713011
5
+ lifecycle_stage: active
6
+ name: Default
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.2.5"
7
+ version = "0.2.7"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -21,7 +21,9 @@ classifiers = [
21
21
  requires-python = ">=3.10"
22
22
  dependencies = [
23
23
  "hydra-core>1.3",
24
+ "joblib",
24
25
  "mlflow>2.15",
26
+ "rich",
25
27
  "setuptools",
26
28
  "watchdog",
27
29
  "watchfiles",
@@ -1,9 +1,9 @@
1
1
  from .context import chdir_artifact, log_run, start_run, watch
2
+ from .info import load_config
2
3
  from .mlflow import get_artifact_dir, get_hydra_output_dir, set_experiment
3
- from .runs import (
4
+ from .run_collection import (
4
5
  RunCollection,
5
6
  list_runs,
6
- load_config,
7
7
  search_runs,
8
8
  )
9
9
 
@@ -0,0 +1,63 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from omegaconf import DictConfig, OmegaConf
6
+
7
+ from hydraflow.mlflow import get_artifact_dir
8
+
9
+ if TYPE_CHECKING:
10
+ from pathlib import Path
11
+
12
+ from mlflow.entities import Run
13
+
14
+ from hydraflow.run_collection import RunCollection
15
+
16
+
17
+ class RunCollectionInfo:
18
+ def __init__(self, runs: RunCollection):
19
+ self._runs = runs
20
+
21
+ @property
22
+ def run_id(self) -> list[str]:
23
+ return [run.info.run_id for run in self._runs]
24
+
25
+ @property
26
+ def params(self) -> list[dict[str, str]]:
27
+ return [run.data.params for run in self._runs]
28
+
29
+ @property
30
+ def metrics(self) -> list[dict[str, float]]:
31
+ return [run.data.metrics for run in self._runs]
32
+
33
+ @property
34
+ def artifact_uri(self) -> list[str | None]:
35
+ return [run.info.artifact_uri for run in self._runs]
36
+
37
+ @property
38
+ def artifact_dir(self) -> list[Path]:
39
+ return [get_artifact_dir(run) for run in self._runs]
40
+
41
+ @property
42
+ def config(self) -> list[DictConfig]:
43
+ return [load_config(run) for run in self._runs]
44
+
45
+
46
+ def load_config(run: Run) -> DictConfig:
47
+ """
48
+ Load the configuration for a given run.
49
+
50
+ This function loads the configuration for the provided Run instance
51
+ by downloading the configuration file from the MLflow artifacts and
52
+ loading it using OmegaConf. It returns an empty config if
53
+ `.hydra/config.yaml` is not found in the run's artifact directory.
54
+
55
+ Args:
56
+ run (Run): The Run instance for which to load the configuration.
57
+
58
+ Returns:
59
+ The loaded configuration as a DictConfig object. Returns an empty
60
+ DictConfig if the configuration file is not found.
61
+ """
62
+ path = get_artifact_dir(run) / ".hydra/config.yaml"
63
+ return OmegaConf.load(path) # type: ignore
@@ -17,6 +17,7 @@ from hydraflow.config import iter_params
17
17
 
18
18
  if TYPE_CHECKING:
19
19
  from mlflow.entities.experiment import Experiment
20
+ from mlflow.entities.run import Run
20
21
 
21
22
 
22
23
  def set_experiment(
@@ -65,60 +66,54 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
65
66
  mlflow.log_param(key, value, synchronous=synchronous)
66
67
 
67
68
 
68
- def get_artifact_dir(
69
- artifact_path: str | None = None,
70
- *,
71
- run_id: str | None = None,
72
- ) -> Path:
69
+ def get_artifact_dir(run: Run | None = None) -> Path:
73
70
  """
74
- Get the artifact directory for the given artifact path.
71
+ Retrieve the artifact directory for the given run.
75
72
 
76
- This function retrieves the artifact URI for the specified artifact path
77
- using MLflow, downloads the artifacts to a local directory, and returns
78
- the path to that directory.
73
+ This function uses MLflow to get the artifact directory for the given run.
79
74
 
80
75
  Args:
81
- artifact_path (str | None): The artifact path for which to get the
82
- directory. Defaults to None.
83
- run_id (str | None): The run ID for which to get the artifact directory.
76
+ run (Run | None): The run object. Defaults to None.
84
77
 
85
78
  Returns:
86
79
  The local path to the directory where the artifacts are downloaded.
87
80
  """
88
- if run_id is None:
89
- uri = mlflow.get_artifact_uri(artifact_path)
81
+ if run is None:
82
+ uri = mlflow.get_artifact_uri()
90
83
  else:
91
- uri = artifact_utils.get_artifact_uri(run_id, artifact_path)
84
+ uri = artifact_utils.get_artifact_uri(run.info.run_id)
92
85
 
93
- dir = mlflow.artifacts.download_artifacts(artifact_uri=uri)
86
+ return Path(mlflow.artifacts.download_artifacts(uri))
94
87
 
95
- return Path(dir)
96
88
 
89
+ def get_hydra_output_dir(*, run: Run | None = None) -> Path:
90
+ """
91
+ Retrieve the Hydra output directory for the given run.
92
+
93
+ This function returns the Hydra output directory. If no run is provided,
94
+ it retrieves the output directory from the current Hydra configuration.
95
+ If a run is provided, it retrieves the artifact path for the run, loads
96
+ the Hydra configuration from the downloaded artifacts, and returns the
97
+ output directory specified in that configuration.
98
+
99
+ Args:
100
+ run (Run | None): The run object. Defaults to None.
101
+
102
+ Returns:
103
+ Path: The path to the Hydra output directory.
97
104
 
98
- def get_hydra_output_dir(*, run_id: str | None = None) -> Path:
99
- if run_id is None:
105
+ Raises:
106
+ FileNotFoundError: If the Hydra configuration file is not found
107
+ in the artifacts.
108
+ """
109
+ if run is None:
100
110
  hc = HydraConfig.get()
101
111
  return Path(hc.runtime.output_dir)
102
112
 
103
- path = get_artifact_dir(run_id=run_id) / ".hydra/hydra.yaml"
113
+ path = get_artifact_dir(run) / ".hydra/hydra.yaml"
104
114
 
105
115
  if path.exists():
106
116
  hc = OmegaConf.load(path)
107
117
  return Path(hc.hydra.runtime.output_dir)
108
118
 
109
119
  raise FileNotFoundError
110
-
111
-
112
- # def log_hydra_output_dir(run: Run_ | Series | str) -> None:
113
- # """
114
- # Log the Hydra output directory.
115
-
116
- # Args:
117
- # run: The run object.
118
-
119
- # Returns:
120
- # None
121
- # """
122
- # output_dir = get_hydra_output_dir(run)
123
- # run_id = run if isinstance(run, str) else run.info.run_id
124
- # mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
@@ -0,0 +1,131 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import joblib
6
+ from rich.progress import Progress
7
+
8
+ if TYPE_CHECKING:
9
+ from collections.abc import Iterable
10
+
11
+ from rich.progress import ProgressColumn
12
+
13
+
14
+ def multi_task_progress(
15
+ iterables: Iterable[Iterable[int | tuple[int, int]]],
16
+ *columns: ProgressColumn | str,
17
+ n_jobs: int = -1,
18
+ description: str = "#{:0>3}",
19
+ main_description: str = "main",
20
+ transient: bool | None = None,
21
+ **kwargs,
22
+ ) -> None:
23
+ """
24
+ Render auto-updating progress bars for multiple tasks concurrently.
25
+
26
+ Args:
27
+ iterables (Iterable[Iterable[int | tuple[int, int]]]): A collection of
28
+ iterables, each representing a task. Each iterable can yield
29
+ integers (completed) or tuples of integers (completed, total).
30
+ *columns (ProgressColumn | str): Additional columns to display in the
31
+ progress bars.
32
+ n_jobs (int, optional): Number of jobs to run in parallel. Defaults to
33
+ -1, which means using all processors.
34
+ description (str, optional): Format string for describing tasks. Defaults to
35
+ "#{:0>3}".
36
+ main_description (str, optional): Description for the main task.
37
+ Defaults to "main".
38
+ transient (bool | None, optional): Whether to remove the progress bar
39
+ after completion. Defaults to None.
40
+ **kwargs: Additional keyword arguments passed to the Progress instance.
41
+
42
+ Returns:
43
+ None
44
+ """
45
+ if not columns:
46
+ columns = Progress.get_default_columns()
47
+
48
+ iterables = list(iterables)
49
+
50
+ with Progress(*columns, transient=transient or False, **kwargs) as progress:
51
+ n = len(iterables)
52
+
53
+ task_main = progress.add_task(main_description, total=None) if n > 1 else None
54
+ tasks = [
55
+ progress.add_task(description.format(i), start=False, total=None) for i in range(n)
56
+ ]
57
+
58
+ total = {}
59
+ completed = {}
60
+
61
+ def func(i: int) -> None:
62
+ completed[i] = 0
63
+ total[i] = None
64
+ progress.start_task(tasks[i])
65
+
66
+ for index in iterables[i]:
67
+ if isinstance(index, tuple):
68
+ completed[i], total[i] = index[0] + 1, index[1]
69
+ else:
70
+ completed[i] = index + 1
71
+
72
+ progress.update(tasks[i], total=total[i], completed=completed[i])
73
+ if task_main is not None:
74
+ if all(t is not None for t in total.values()):
75
+ t = sum(total.values())
76
+ else:
77
+ t = None
78
+ c = sum(completed.values())
79
+ progress.update(task_main, total=t, completed=c)
80
+
81
+ if transient or n > 1:
82
+ progress.remove_task(tasks[i])
83
+
84
+ if n > 1:
85
+ it = (joblib.delayed(func)(i) for i in range(n))
86
+ joblib.Parallel(n_jobs, prefer="threads")(it)
87
+
88
+ else:
89
+ func(0)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ import random
94
+ import time
95
+
96
+ from rich.progress import MofNCompleteColumn, Progress, SpinnerColumn, TimeElapsedColumn
97
+
98
+ from hydraflow.progress import multi_task_progress
99
+
100
+ def task(total):
101
+ for i in range(total or 90):
102
+ if total is None:
103
+ yield i
104
+ else:
105
+ yield i, total
106
+ time.sleep(random.random() / 30)
107
+
108
+ def multi_task_progress_test(unknown_total: bool):
109
+ tasks = [task(random.randint(80, 100)) for _ in range(4)]
110
+ if unknown_total:
111
+ tasks = [task(None), *tasks, task(None)]
112
+
113
+ columns = [
114
+ SpinnerColumn(),
115
+ *Progress.get_default_columns(),
116
+ MofNCompleteColumn(),
117
+ TimeElapsedColumn(),
118
+ ]
119
+
120
+ kwargs = {}
121
+ if unknown_total:
122
+ kwargs["main_description"] = "unknown"
123
+
124
+ multi_task_progress(tasks, *columns, n_jobs=4, **kwargs)
125
+
126
+ multi_task_progress_test(False)
127
+ multi_task_progress_test(True)
128
+ multi_task_progress([task(100)])
129
+ multi_task_progress([task(None)], description="unknown")
130
+ multi_task_progress([task(100), task(None)], main_description="transient", transient=True)
131
+ multi_task_progress([task(100)], description="transient", transient=True)
@@ -6,24 +6,25 @@ runs, retrieve run information, log artifacts, and load configurations.
6
6
 
7
7
  from __future__ import annotations
8
8
 
9
- from dataclasses import dataclass
10
- from functools import cache
9
+ from dataclasses import dataclass, field
11
10
  from itertools import chain
12
- from typing import TYPE_CHECKING, Any, TypeVar
11
+ from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar
13
12
 
14
13
  import mlflow
15
- from mlflow.artifacts import download_artifacts
16
14
  from mlflow.entities import ViewType
17
15
  from mlflow.entities.run import Run
18
16
  from mlflow.tracking.fluent import SEARCH_MAX_RESULTS_PANDAS
19
- from omegaconf import DictConfig, OmegaConf
20
17
 
21
18
  from hydraflow.config import iter_params
19
+ from hydraflow.info import RunCollectionInfo
22
20
 
23
21
  if TYPE_CHECKING:
24
22
  from collections.abc import Callable, Iterator
23
+ from pathlib import Path
25
24
  from typing import Any
26
25
 
26
+ from omegaconf import DictConfig
27
+
27
28
 
28
29
  def search_runs(
29
30
  experiment_ids: list[str] | None = None,
@@ -51,13 +52,6 @@ def search_runs(
51
52
  error if ``experiment_names`` is also not ``None`` or ``[]``.
52
53
  ``None`` will default to the active experiment if ``experiment_names``
53
54
  is ``None`` or ``[]``.
54
- experiment_ids (list[str] | None): List of experiment IDs. Search can
55
- work with experiment IDs or experiment names, but not both in the
56
- same call. Values other than ``None`` or ``[]`` will result in
57
- error if ``experiment_names`` is also not ``None`` or ``[]``.
58
- ``experiment_names`` is also not ``None`` or ``[]``. ``None`` will
59
- default to the active experiment if ``experiment_names`` is ``None``
60
- or ``[]``.
61
55
  filter_string (str): Filter query string, defaults to searching all
62
56
  runs.
63
57
  run_view_type (int): one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``,
@@ -128,6 +122,7 @@ def list_runs(experiment_names: list[str] | None = None) -> RunCollection:
128
122
 
129
123
 
130
124
  T = TypeVar("T")
125
+ P = ParamSpec("P")
131
126
 
132
127
 
133
128
  @dataclass
@@ -142,6 +137,12 @@ class RunCollection:
142
137
  _runs: list[Run]
143
138
  """A list of MLflow Run objects."""
144
139
 
140
+ _info: RunCollectionInfo = field(init=False)
141
+ """A list of MLflow Run objects."""
142
+
143
+ def __post_init__(self):
144
+ self._info = RunCollectionInfo(self)
145
+
145
146
  def __repr__(self) -> str:
146
147
  return f"{self.__class__.__name__}({len(self)})"
147
148
 
@@ -157,6 +158,10 @@ class RunCollection:
157
158
  def __contains__(self, run: Run) -> bool:
158
159
  return run in self._runs
159
160
 
161
+ @property
162
+ def info(self) -> RunCollectionInfo:
163
+ return self._info
164
+
160
165
  def sort(
161
166
  self,
162
167
  key: Callable[[Run], Any] | None = None,
@@ -418,52 +423,81 @@ class RunCollection:
418
423
  """
419
424
  return get_param_dict(self._runs)
420
425
 
421
- def map(self, func: Callable[[Run], T]) -> Iterator[T]:
426
+ def map(
427
+ self,
428
+ func: Callable[Concatenate[Run, P], T],
429
+ *args: P.args,
430
+ **kwargs: P.kwargs,
431
+ ) -> Iterator[T]:
422
432
  """
423
433
  Apply a function to each run in the collection and return an iterator of
424
434
  results.
425
435
 
436
+ This method iterates over each run in the collection and applies the
437
+ provided function to it, along with any additional arguments and
438
+ keyword arguments.
439
+
426
440
  Args:
427
- func (Callable[[Run], T]): A function that takes a run and returns a
428
- result.
441
+ func (Callable[[Run, P], T]): A function that takes a run and
442
+ additional arguments and returns a result.
443
+ *args: Additional arguments to pass to the function.
444
+ **kwargs: Additional keyword arguments to pass to the function.
429
445
 
430
446
  Yields:
431
- Results obtained by applying the function to each run in the
432
- collection.
447
+ Results obtained by applying the function to each run in the collection.
433
448
  """
434
- return (func(run) for run in self._runs)
449
+ return (func(run, *args, **kwargs) for run in self)
435
450
 
436
- def map_run_id(self, func: Callable[[str], T]) -> Iterator[T]:
451
+ def map_run_id(
452
+ self,
453
+ func: Callable[Concatenate[str, P], T],
454
+ *args: P.args,
455
+ **kwargs: P.kwargs,
456
+ ) -> Iterator[T]:
437
457
  """
438
458
  Apply a function to each run id in the collection and return an iterator
439
459
  of results.
440
460
 
441
461
  Args:
442
- func (Callable[[str], T]): A function that takes a run id and returns a
462
+ func (Callable[[str, P], T]): A function that takes a run id and returns a
443
463
  result.
464
+ *args: Additional arguments to pass to the function.
465
+ **kwargs: Additional keyword arguments to pass to the function.
444
466
 
445
467
  Yields:
446
468
  Results obtained by applying the function to each run id in the
447
469
  collection.
448
470
  """
449
- return (func(run.info.run_id) for run in self._runs)
471
+ return (func(run_id, *args, **kwargs) for run_id in self.info.run_id)
450
472
 
451
- def map_config(self, func: Callable[[DictConfig], T]) -> Iterator[T]:
473
+ def map_config(
474
+ self,
475
+ func: Callable[Concatenate[DictConfig, P], T],
476
+ *args: P.args,
477
+ **kwargs: P.kwargs,
478
+ ) -> Iterator[T]:
452
479
  """
453
480
  Apply a function to each run configuration in the collection and return
454
481
  an iterator of results.
455
482
 
456
483
  Args:
457
- func (Callable[[DictConfig], T]): A function that takes a run
484
+ func (Callable[[DictConfig, P], T]): A function that takes a run
458
485
  configuration and returns a result.
486
+ *args: Additional arguments to pass to the function.
487
+ **kwargs: Additional keyword arguments to pass to the function.
459
488
 
460
489
  Yields:
461
490
  Results obtained by applying the function to each run configuration
462
491
  in the collection.
463
492
  """
464
- return (func(load_config(run)) for run in self._runs)
493
+ return (func(config, *args, **kwargs) for config in self.info.config)
465
494
 
466
- def map_uri(self, func: Callable[[str | None], T]) -> Iterator[T]:
495
+ def map_uri(
496
+ self,
497
+ func: Callable[Concatenate[str | None, P], T],
498
+ *args: P.args,
499
+ **kwargs: P.kwargs,
500
+ ) -> Iterator[T]:
467
501
  """
468
502
  Apply a function to each artifact URI in the collection and return an
469
503
  iterator of results.
@@ -473,16 +507,23 @@ class RunCollection:
473
507
  have an artifact URI, None is passed to the function.
474
508
 
475
509
  Args:
476
- func (Callable[[str | None], T]): A function that takes an
477
- artifact URI (string or None) and returns a result.
510
+ func (Callable[[str | None, P], T]): A function that takes an
511
+ artifact URI (string or None) and returns a result.
512
+ *args: Additional arguments to pass to the function.
513
+ **kwargs: Additional keyword arguments to pass to the function.
478
514
 
479
515
  Yields:
480
516
  Results obtained by applying the function to each artifact URI in the
481
517
  collection.
482
518
  """
483
- return (func(run.info.artifact_uri) for run in self._runs)
519
+ return (func(uri, *args, **kwargs) for uri in self.info.artifact_uri)
484
520
 
485
- def map_dir(self, func: Callable[[str], T]) -> Iterator[T]:
521
+ def map_dir(
522
+ self,
523
+ func: Callable[Concatenate[Path, P], T],
524
+ *args: P.args,
525
+ **kwargs: P.kwargs,
526
+ ) -> Iterator[T]:
486
527
  """
487
528
  Apply a function to each artifact directory in the collection and return
488
529
  an iterator of results.
@@ -492,42 +533,61 @@ class RunCollection:
492
533
  path.
493
534
 
494
535
  Args:
495
- func (Callable[[str], T]): A function that takes an artifact directory
536
+ func (Callable[[Path, P], T]): A function that takes an artifact directory
496
537
  path (string) and returns a result.
538
+ *args: Additional arguments to pass to the function.
539
+ **kwargs: Additional keyword arguments to pass to the function.
497
540
 
498
541
  Yields:
499
542
  Results obtained by applying the function to each artifact directory
500
543
  in the collection.
501
544
  """
502
- return (func(download_artifacts(run_id=run.info.run_id)) for run in self._runs)
545
+ return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir)
503
546
 
504
- def group_by(
505
- self, names: list[str] | None = None, *args
506
- ) -> dict[tuple[str, ...], RunCollection]:
547
+ def group_by(self, *names: str | list[str]) -> dict[tuple[str | None, ...], RunCollection]:
507
548
  """
508
- Group the runs by the specified parameter names and return a dictionary
509
- where the keys are the parameter values and the values are the runs.
549
+ Group runs by specified parameter names.
550
+
551
+ This method groups the runs in the collection based on the values of the
552
+ specified parameters. Each unique combination of parameter values will
553
+ form a key in the returned dictionary.
510
554
 
511
555
  Args:
512
- names (list[str] | None): The parameter names to group by.
513
- *args: Additional positional arguments to specify parameter names.
556
+ *names (str | list[str]): The names of the parameters to group by.
557
+ This can be a single parameter name or multiple names provided
558
+ as separate arguments or as a list.
514
559
 
515
560
  Returns:
516
- A dictionary where the keys are the parameter values and the values
517
- are the runs.
561
+ dict[tuple[str | None, ...], RunCollection]: A dictionary where the keys
562
+ are tuples of parameter values and the values are RunCollection objects
563
+ containing the runs that match those parameter values.
518
564
  """
519
- names = names[:] if names else []
520
- names.extend(args)
521
-
522
- grouped_runs = {}
565
+ grouped_runs: dict[tuple[str | None, ...], list[Run]] = {}
523
566
  for run in self._runs:
524
- key = get_params(run, names)
525
- if key not in grouped_runs:
526
- grouped_runs[key] = []
527
- grouped_runs[key].append(run)
567
+ key = get_params(run, *names)
568
+ grouped_runs.setdefault(key, []).append(run)
528
569
 
529
570
  return {key: RunCollection(runs) for key, runs in grouped_runs.items()}
530
571
 
572
+ def group_by_values(self, *names: str | list[str]) -> list[RunCollection]:
573
+ """
574
+ Group runs by specified parameter names.
575
+
576
+ This method groups the runs in the collection based on the values of the
577
+ specified parameters. Each unique combination of parameter values will
578
+ form a separate RunCollection in the returned list.
579
+
580
+ Args:
581
+ *names (str | list[str]): The names of the parameters to group by.
582
+ This can be a single parameter name or multiple names provided
583
+ as separate arguments or as a list.
584
+
585
+ Returns:
586
+ list[RunCollection]: A list of RunCollection objects, where each
587
+ object contains runs that match the specified parameter values.
588
+ """
589
+ return list(self.group_by(*names).values())
590
+
531
591
 
532
592
  def _param_matches(run: Run, key: str, value: Any) -> bool:
533
593
  """
@@ -792,11 +852,32 @@ def try_get_run(runs: list[Run], config: object | None = None, **kwargs) -> Run
792
852
  raise ValueError(msg)
793
853
 
794
854
 
795
- def get_params(run: Run, names: list[str] | None = None, *args) -> tuple[str, ...]:
796
- names = names[:] if names else []
797
- names.extend(args)
855
+ def get_params(run: Run, *names: str | list[str]) -> tuple[str | None, ...]:
856
+ """
857
+ Retrieve the values of specified parameters from the given run.
858
+
859
+ This function extracts the values of the parameters identified by the
860
+ provided names from the specified run. It can accept both individual
861
+ parameter names and lists of parameter names.
862
+
863
+ Args:
864
+ run (Run): The run object from which to extract parameter values.
865
+ *names (str | list[str]): The names of the parameters to retrieve.
866
+ This can be a single parameter name or multiple names provided
867
+ as separate arguments or as a list.
868
+
869
+ Returns:
870
+ tuple[str | None, ...]: A tuple containing the values of the specified
871
+ parameters in the order they were provided.
872
+ """
873
+ names_ = []
874
+ for name in names:
875
+ if isinstance(name, list):
876
+ names_.extend(name)
877
+ else:
878
+ names_.append(name)
798
879
 
799
- return tuple(run.data.params[name] for name in names)
880
+ return tuple(run.data.params.get(name) for name in names_)
800
881
 
801
882
 
802
883
  def get_param_names(runs: list[Run]) -> list[str]:
@@ -846,33 +927,3 @@ def get_param_dict(runs: list[Run]) -> dict[str, list[str]]:
846
927
  params[name] = sorted(set(it))
847
928
 
848
929
  return params
849
-
850
-
851
- def load_config(run: Run) -> DictConfig:
852
- """
853
- Load the configuration for a given run.
854
-
855
- This function loads the configuration for the provided Run instance
856
- by downloading the configuration file from the MLflow artifacts and
857
- loading it using OmegaConf. It returns an empty config if
858
- `.hydra/config.yaml` is not found in the run's artifact directory.
859
-
860
- Args:
861
- run (Run): The Run instance for which to load the configuration.
862
-
863
- Returns:
864
- The loaded configuration as a DictConfig object. Returns an empty
865
- DictConfig if the configuration file is not found.
866
- """
867
- run_id = run.info.run_id
868
- return _load_config(run_id)
869
-
870
-
871
- @cache
872
- def _load_config(run_id: str) -> DictConfig:
873
- try:
874
- path = download_artifacts(run_id=run_id, artifact_path=".hydra/config.yaml")
875
- except OSError:
876
- return DictConfig({})
877
-
878
- return OmegaConf.load(path) # type: ignore
@@ -77,6 +77,7 @@ async def test_monitor_file_changes(tmp_path: Path, write_soon: Callable[[Path],
77
77
  await asyncio.sleep(1)
78
78
  stop_event.set()
79
79
  await monitor_task
80
+ await asyncio.sleep(1)
80
81
 
81
82
  assert len(changes_detected) > 0
82
83
 
@@ -1,15 +1,17 @@
1
+ import time
2
+ from pathlib import Path
1
3
  from unittest.mock import MagicMock, patch
2
4
 
3
5
  import mlflow
4
6
  import pytest
5
7
 
6
8
  from hydraflow.context import log_run, start_run, watch
7
- from hydraflow.runs import RunCollection
9
+ from hydraflow.run_collection import RunCollection
8
10
 
9
11
 
10
12
  @pytest.fixture
11
13
  def runs(monkeypatch, tmp_path):
12
- from hydraflow.runs import list_runs
14
+ from hydraflow.run_collection import list_runs
13
15
 
14
16
  monkeypatch.chdir(tmp_path)
15
17
 
@@ -17,7 +19,7 @@ def runs(monkeypatch, tmp_path):
17
19
  patch("hydraflow.context.HydraConfig.get") as mock_hydra_config,
18
20
  patch("hydraflow.context.mlflow.log_artifacts") as mock_log_artifacts,
19
21
  ):
20
- mock_hydra_config.return_value.runtime.output_dir = "/tmp"
22
+ mock_hydra_config.return_value.runtime.output_dir = tmp_path.as_posix()
21
23
  mock_log_artifacts.return_value = None
22
24
 
23
25
  mlflow.set_experiment("test_run")
@@ -49,7 +51,7 @@ def test_runs_params_dict(runs: RunCollection, i: int):
49
51
  assert runs[i].data.params["d.i"] == str(i)
50
52
 
51
53
 
52
- def test_log_run_error_handling():
54
+ def test_log_run_error_handling(tmp_path: Path):
53
55
  config = MagicMock()
54
56
  config.some_param = "value"
55
57
 
@@ -59,7 +61,7 @@ def test_log_run_error_handling():
59
61
  patch("hydraflow.context.mlflow.log_artifacts") as mock_log_artifacts,
60
62
  ):
61
63
  mock_log_params.side_effect = Exception("Test exception")
62
- mock_hydra_config.return_value.runtime.output_dir = "/tmp"
64
+ mock_hydra_config.return_value.runtime.output_dir = tmp_path.as_posix()
63
65
  mock_log_artifacts.return_value = None
64
66
 
65
67
  with pytest.raises(Exception, match="Test exception"):
@@ -67,14 +69,20 @@ def test_log_run_error_handling():
67
69
  pass
68
70
 
69
71
 
70
- def test_watch_error_handling():
71
- func = MagicMock()
72
- dir = "/tmp"
72
+ def test_watch_context_manager(tmp_path: Path):
73
+ test_dir = tmp_path / "test_watch"
74
+ test_dir.mkdir(parents=True, exist_ok=True)
75
+ test_file = test_dir / "test_file.txt"
73
76
 
74
- with patch("hydraflow.context.Observer") as mock_observer:
75
- mock_observer_instance = mock_observer.return_value
76
- mock_observer_instance.start.side_effect = Exception("Test exception")
77
+ called = []
77
78
 
78
- with pytest.raises(Exception, match="Test exception"):
79
- with watch(func, dir):
80
- pass
79
+ def mock_func(path: Path):
80
+ assert path == test_file
81
+ called.append(path)
82
+
83
+ with watch(mock_func, test_dir):
84
+ test_file.write_text("new content")
85
+ time.sleep(1)
86
+
87
+ assert len(called) == 1
88
+ assert called[0] == test_file
@@ -0,0 +1,51 @@
1
+ from pathlib import Path
2
+
3
+ import mlflow
4
+ import pytest
5
+
6
+ from hydraflow.run_collection import RunCollection
7
+
8
+
9
+ @pytest.fixture
10
+ def runs(monkeypatch, tmp_path):
11
+ from hydraflow.run_collection import search_runs
12
+
13
+ monkeypatch.chdir(tmp_path)
14
+
15
+ mlflow.set_experiment("test_info")
16
+ for x in range(3):
17
+ with mlflow.start_run(run_name=f"{x}"):
18
+ mlflow.log_param("p", x)
19
+ mlflow.log_metric("metric1", x + 1)
20
+ mlflow.log_metric("metric2", x + 2)
21
+
22
+ x = search_runs()
23
+ assert isinstance(x, RunCollection)
24
+ return x
25
+
26
+
27
+ def test_info_run_id(runs: RunCollection):
28
+ assert len(runs.info.run_id) == 3
29
+
30
+
31
+ def test_info_params(runs: RunCollection):
32
+ assert runs.info.params == [{"p": "0"}, {"p": "1"}, {"p": "2"}]
33
+
34
+
35
+ def test_info_metrics(runs: RunCollection):
36
+ m = runs.info.metrics
37
+ assert m[0] == {"metric1": 1, "metric2": 2}
38
+ assert m[1] == {"metric1": 2, "metric2": 3}
39
+ assert m[2] == {"metric1": 3, "metric2": 4}
40
+
41
+
42
+ def test_info_artifact_uri(runs: RunCollection):
43
+ uri = runs.info.artifact_uri
44
+ assert all(u.startswith("file://") for u in uri) # type: ignore
45
+ assert all(u.endswith("/artifacts") for u in uri) # type: ignore
46
+
47
+
48
+ def test_info_artifact_dir(runs: RunCollection):
49
+ dir = runs.info.artifact_dir
50
+ assert all(isinstance(d, Path) for d in dir)
51
+ assert all(d.stem == "artifacts" for d in dir) # type: ignore
@@ -49,7 +49,7 @@ def read_log(run_id: str, path: str) -> str:
49
49
 
50
50
 
51
51
  def test_load_config(run: Run):
52
- from hydraflow.runs import load_config
52
+ from hydraflow.info import load_config
53
53
 
54
54
  log = read_log(run.info.run_id, "log_run.log")
55
55
  assert "START" in log
@@ -0,0 +1,12 @@
1
+ import sys
2
+ from subprocess import run
3
+
4
+ import pytest
5
+
6
+
7
+ @pytest.mark.skipif(
8
+ sys.platform == "win32", reason="'cp932' codec can't encode character '\\u2807'"
9
+ )
10
+ def test_progress_bar():
11
+ cp = run([sys.executable, "-m", "hydraflow.progress"])
12
+ assert cp.returncode == 0
@@ -5,14 +5,13 @@ from pathlib import Path
5
5
  import mlflow
6
6
  import pytest
7
7
  from mlflow.entities import Run
8
- from omegaconf import DictConfig
9
8
 
10
- from hydraflow.runs import RunCollection
9
+ from hydraflow.run_collection import RunCollection
11
10
 
12
11
 
13
12
  @pytest.fixture
14
13
  def runs(monkeypatch, tmp_path):
15
- from hydraflow.runs import search_runs
14
+ from hydraflow.run_collection import search_runs
16
15
 
17
16
  monkeypatch.chdir(tmp_path)
18
17
 
@@ -39,13 +38,13 @@ def test_search_runs_sorted(run_list: list[Run]):
39
38
 
40
39
 
41
40
  def test_filter_none(run_list: list[Run]):
42
- from hydraflow.runs import filter_runs
41
+ from hydraflow.run_collection import filter_runs
43
42
 
44
43
  assert run_list == filter_runs(run_list)
45
44
 
46
45
 
47
46
  def test_filter_one(run_list: list[Run]):
48
- from hydraflow.runs import filter_runs
47
+ from hydraflow.run_collection import filter_runs
49
48
 
50
49
  assert len(run_list) == 6
51
50
  x = filter_runs(run_list, {"p": 1})
@@ -55,7 +54,7 @@ def test_filter_one(run_list: list[Run]):
55
54
 
56
55
 
57
56
  def test_filter_all(run_list: list[Run]):
58
- from hydraflow.runs import filter_runs
57
+ from hydraflow.run_collection import filter_runs
59
58
 
60
59
  assert len(run_list) == 6
61
60
  x = filter_runs(run_list, {"q": 0})
@@ -65,28 +64,28 @@ def test_filter_all(run_list: list[Run]):
65
64
 
66
65
 
67
66
  def test_filter_list(run_list: list[Run]):
68
- from hydraflow.runs import filter_runs
67
+ from hydraflow.run_collection import filter_runs
69
68
 
70
69
  x = filter_runs(run_list, p=[0, 4, 5])
71
70
  assert len(x) == 3
72
71
 
73
72
 
74
73
  def test_filter_tuple(run_list: list[Run]):
75
- from hydraflow.runs import filter_runs
74
+ from hydraflow.run_collection import filter_runs
76
75
 
77
76
  x = filter_runs(run_list, p=(1, 3))
78
77
  assert len(x) == 2
79
78
 
80
79
 
81
80
  def test_filter_invalid_param(run_list: list[Run]):
82
- from hydraflow.runs import filter_runs
81
+ from hydraflow.run_collection import filter_runs
83
82
 
84
83
  x = filter_runs(run_list, {"invalid": 0})
85
84
  assert len(x) == 6
86
85
 
87
86
 
88
87
  def test_find_run(run_list: list[Run]):
89
- from hydraflow.runs import find_run, try_find_run
88
+ from hydraflow.run_collection import find_run, try_find_run
90
89
 
91
90
  x = find_run(run_list, {"r": 1})
92
91
  assert isinstance(x, Run)
@@ -100,20 +99,20 @@ def test_find_run(run_list: list[Run]):
100
99
 
101
100
 
102
101
  def test_find_run_none(run_list: list[Run]):
103
- from hydraflow.runs import find_run
102
+ from hydraflow.run_collection import find_run
104
103
 
105
104
  with pytest.raises(ValueError):
106
105
  find_run(run_list, {"r": 10})
107
106
 
108
107
 
109
108
  def test_try_find_run_none_empty(run_list: list[Run]):
110
- from hydraflow.runs import try_find_run
109
+ from hydraflow.run_collection import try_find_run
111
110
 
112
111
  assert try_find_run([]) is None
113
112
 
114
113
 
115
114
  def test_find_last_run(run_list: list[Run]):
116
- from hydraflow.runs import find_last_run, try_find_last_run
115
+ from hydraflow.run_collection import find_last_run, try_find_last_run
117
116
 
118
117
  x = find_last_run(run_list, {"r": 1})
119
118
  assert isinstance(x, Run)
@@ -127,20 +126,20 @@ def test_find_last_run(run_list: list[Run]):
127
126
 
128
127
 
129
128
  def test_find_last_run_none(run_list: list[Run]):
130
- from hydraflow.runs import find_last_run
129
+ from hydraflow.run_collection import find_last_run
131
130
 
132
131
  with pytest.raises(ValueError):
133
132
  find_last_run(run_list, {"r": 10})
134
133
 
135
134
 
136
135
  def test_try_find_last_run_none(run_list: list[Run]):
137
- from hydraflow.runs import try_find_last_run
136
+ from hydraflow.run_collection import try_find_last_run
138
137
 
139
138
  assert try_find_last_run([]) is None
140
139
 
141
140
 
142
141
  def test_get_run(run_list: list[Run]):
143
- from hydraflow.runs import get_run
142
+ from hydraflow.run_collection import get_run
144
143
 
145
144
  run = get_run(run_list, {"p": 4})
146
145
  assert isinstance(run, Run)
@@ -148,7 +147,7 @@ def test_get_run(run_list: list[Run]):
148
147
 
149
148
 
150
149
  def test_get_run_error(run_list: list[Run]):
151
- from hydraflow.runs import get_run
150
+ from hydraflow.run_collection import get_run
152
151
 
153
152
  with pytest.raises(ValueError):
154
153
  get_run(run_list, {"q": 0})
@@ -158,20 +157,30 @@ def test_get_run_error(run_list: list[Run]):
158
157
 
159
158
 
160
159
  def test_try_get_run_none(run_list: list[Run]):
161
- from hydraflow.runs import try_get_run
160
+ from hydraflow.run_collection import try_get_run
162
161
 
163
162
  assert try_get_run(run_list, {"q": -1}) is None
164
163
 
165
164
 
166
165
  def test_try_get_run_error(run_list: list[Run]):
167
- from hydraflow.runs import try_get_run
166
+ from hydraflow.run_collection import try_get_run
168
167
 
169
168
  with pytest.raises(ValueError):
170
169
  try_get_run(run_list, {"q": 0})
171
170
 
172
171
 
172
+ def test_get_params(run_list: list[Run]):
173
+ from hydraflow.run_collection import get_params
174
+
175
+ assert get_params(run_list[1], "p") == ("1",)
176
+ assert get_params(run_list[2], "p", "q") == ("2", "0")
177
+ assert get_params(run_list[3], ["p", "q"]) == ("3", "0")
178
+ assert get_params(run_list[4], "p", ["q", "r"]) == ("4", "0", "1")
179
+ assert get_params(run_list[5], ["a", "q"], "r") == (None, "None", "2")
180
+
181
+
173
182
  def test_get_param_names(run_list: list[Run]):
174
- from hydraflow.runs import get_param_names
183
+ from hydraflow.run_collection import get_param_names
175
184
 
176
185
  params = get_param_names(run_list)
177
186
  assert len(params) == 3
@@ -181,7 +190,7 @@ def test_get_param_names(run_list: list[Run]):
181
190
 
182
191
 
183
192
  def test_get_param_dict(run_list: list[Run]):
184
- from hydraflow.runs import get_param_dict
193
+ from hydraflow.run_collection import get_param_dict
185
194
 
186
195
  params = get_param_dict(run_list)
187
196
  assert len(params["p"]) == 6
@@ -250,7 +259,7 @@ def test_runs_filter(runs: RunCollection):
250
259
 
251
260
 
252
261
  def test_runs_get(runs: RunCollection):
253
- from hydraflow.runs import Run
262
+ from hydraflow.run_collection import Run
254
263
 
255
264
  run = runs.get({"p": 4})
256
265
  assert isinstance(run, Run)
@@ -283,7 +292,7 @@ def test_runs_get_params_dict(runs: RunCollection):
283
292
 
284
293
 
285
294
  def test_runs_find(runs: RunCollection):
286
- from hydraflow.runs import Run
295
+ from hydraflow.run_collection import Run
287
296
 
288
297
  run = runs.find({"r": 0})
289
298
  assert isinstance(run, Run)
@@ -304,7 +313,7 @@ def test_runs_try_find_none(runs: RunCollection):
304
313
 
305
314
 
306
315
  def test_runs_find_last(runs: RunCollection):
307
- from hydraflow.runs import Run
316
+ from hydraflow.run_collection import Run
308
317
 
309
318
  run = runs.find_last({"r": 0})
310
319
  assert isinstance(run, Run)
@@ -333,7 +342,7 @@ def runs2(monkeypatch, tmp_path):
333
342
 
334
343
 
335
344
  def test_list_runs(runs, runs2):
336
- from hydraflow.runs import list_runs
345
+ from hydraflow.run_collection import list_runs
337
346
 
338
347
  mlflow.set_experiment("test_run")
339
348
  all_runs = list_runs()
@@ -345,7 +354,7 @@ def test_list_runs(runs, runs2):
345
354
 
346
355
 
347
356
  def test_list_runs_empty_list(runs, runs2):
348
- from hydraflow.runs import list_runs
357
+ from hydraflow.run_collection import list_runs
349
358
 
350
359
  all_runs = list_runs([])
351
360
  assert len(all_runs) == 9
@@ -353,14 +362,14 @@ def test_list_runs_empty_list(runs, runs2):
353
362
 
354
363
  @pytest.mark.parametrize(["name", "n"], [("test_run", 6), ("test_run2", 3)])
355
364
  def test_list_runs_list(runs, runs2, name, n):
356
- from hydraflow.runs import list_runs
365
+ from hydraflow.run_collection import list_runs
357
366
 
358
367
  filtered_runs = list_runs(experiment_names=[name])
359
368
  assert len(filtered_runs) == n
360
369
 
361
370
 
362
371
  def test_list_runs_none(runs, runs2):
363
- from hydraflow.runs import list_runs
372
+ from hydraflow.run_collection import list_runs
364
373
 
365
374
  no_runs = list_runs(experiment_names=["non_existent_experiment"])
366
375
  assert len(no_runs) == 0
@@ -372,16 +381,20 @@ def test_run_collection_map(runs: RunCollection):
372
381
  assert all(isinstance(run_id, str) for run_id in results)
373
382
 
374
383
 
384
+ def test_run_collection_map_args(runs: RunCollection):
385
+ results = list(runs.map(lambda run, x: run.info.run_id + x, "test"))
386
+ assert all(x.endswith("test") for x in results)
387
+
388
+
375
389
  def test_run_collection_map_run_id(runs: RunCollection):
376
390
  results = list(runs.map_run_id(lambda run_id: run_id))
377
391
  assert len(results) == len(runs._runs)
378
392
  assert all(isinstance(run_id, str) for run_id in results)
379
393
 
380
394
 
381
- def test_run_collection_map_config(runs: RunCollection):
382
- results = list(runs.map_config(lambda config: config))
383
- assert len(results) == len(runs._runs)
384
- assert all(isinstance(config, DictConfig) for config in results)
395
+ def test_run_collection_map_run_id_kwargs(runs: RunCollection):
396
+ results = list(runs.map_run_id(lambda run_id, x: x + run_id, x="test"))
397
+ assert all(x.startswith("test") for x in results)
385
398
 
386
399
 
387
400
  def test_run_collection_map_uri(runs: RunCollection):
@@ -391,9 +404,10 @@ def test_run_collection_map_uri(runs: RunCollection):
391
404
 
392
405
 
393
406
  def test_run_collection_map_dir(runs: RunCollection):
394
- results = list(runs.map_dir(lambda dir_path: dir_path))
407
+ results = list(runs.map_dir(lambda dir_path, x: dir_path / x, "a.csv"))
395
408
  assert len(results) == len(runs._runs)
396
- assert all(isinstance(dir_path, str) for dir_path in results)
409
+ assert all(isinstance(dir_path, Path) for dir_path in results)
410
+ assert all(dir_path.stem == "a" for dir_path in results)
397
411
 
398
412
 
399
413
  def test_run_collection_sort(runs: RunCollection):
@@ -427,15 +441,53 @@ def test_run_collection_group_by(runs: RunCollection):
427
441
  assert grouped[("0",)][0] == runs[0]
428
442
  assert grouped[("1",)][0] == runs[1]
429
443
 
430
- grouped = runs.group_by(["q"])
444
+ grouped = runs.group_by("q")
431
445
  assert len(grouped) == 2
432
446
 
433
- grouped = runs.group_by(["r"])
447
+ grouped = runs.group_by("r")
434
448
  assert len(grouped) == 3
435
449
 
436
450
 
437
- # def test_hydra_output_dir_error(runs_list: list[Run]):
438
- # from hydraflow.runs import get_hydra_output_dir
451
+ def test_filter_runs_empty_list():
452
+ from hydraflow.run_collection import filter_runs
453
+
454
+ x = filter_runs([], p=[0, 1, 2])
455
+ assert x == []
456
+
457
+
458
+ def test_filter_runs_no_match(run_list: list[Run]):
459
+ from hydraflow.run_collection import filter_runs
460
+
461
+ x = filter_runs(run_list, p=[10, 11, 12])
462
+ assert x == []
463
+
464
+
465
+ def test_get_run_no_match(run_list: list[Run]):
466
+ from hydraflow.run_collection import get_run
467
+
468
+ with pytest.raises(ValueError):
469
+ get_run(run_list, {"p": 10})
470
+
471
+
472
+ def test_get_run_multiple_params(run_list: list[Run]):
473
+ from hydraflow.run_collection import get_run
474
+
475
+ run = get_run(run_list, {"p": 4, "q": 0})
476
+ assert isinstance(run, Run)
477
+ assert run.data.params["p"] == "4"
478
+ assert run.data.params["q"] == "0"
479
+
439
480
 
440
- # with pytest.raises(FileNotFoundError):
441
- # get_hydra_output_dir(runs_list[0])
481
+ def test_try_get_run_no_match(run_list: list[Run]):
482
+ from hydraflow.run_collection import try_get_run
483
+
484
+ assert try_get_run(run_list, {"p": 10}) is None
485
+
486
+
487
+ def test_try_get_run_multiple_params(run_list: list[Run]):
488
+ from hydraflow.run_collection import try_get_run
489
+
490
+ run = try_get_run(run_list, {"p": 4, "q": 0})
491
+ assert isinstance(run, Run)
492
+ assert run.data.params["p"] == "4"
493
+ assert run.data.params["q"] == "0"
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import subprocess
4
+ import time
4
5
  from pathlib import Path
5
6
 
6
7
  import pytest
@@ -21,6 +22,7 @@ def test_watch(dir, monkeypatch, tmp_path):
21
22
 
22
23
  with watch(func, dir if isinstance(dir, str) else dir()):
23
24
  subprocess.check_call(["python", file])
25
+ time.sleep(1)
24
26
 
25
- assert results[0][0] == "watch.txt"
26
- assert results[0][1] == "watch"
27
+ assert results[0][0] == "watch.txt" # type: ignore
28
+ assert results[0][1] == "watch" # type: ignore
File without changes
File without changes
File without changes
File without changes