hydraflow 0.2.14.tar.gz → 0.2.16.tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (35)
  1. {hydraflow-0.2.14 → hydraflow-0.2.16}/.gitignore +2 -1
  2. {hydraflow-0.2.14 → hydraflow-0.2.16}/PKG-INFO +3 -10
  3. {hydraflow-0.2.14 → hydraflow-0.2.16}/pyproject.toml +27 -13
  4. {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/asyncio.py +39 -19
  5. {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/config.py +3 -3
  6. {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/context.py +28 -20
  7. {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/info.py +1 -1
  8. {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/mlflow.py +4 -2
  9. {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/progress.py +2 -2
  10. {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/run_collection.py +26 -9
  11. hydraflow-0.2.16/tests/__init__.py +0 -0
  12. hydraflow-0.2.16/tests/scripts/__init__.py +0 -0
  13. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_app.py +8 -6
  14. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_asyncio.py +9 -3
  15. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_config.py +5 -1
  16. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_log_run.py +1 -2
  17. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_progress.py +2 -2
  18. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_run_collection.py +11 -7
  19. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_watch.py +1 -1
  20. {hydraflow-0.2.14 → hydraflow-0.2.16}/.devcontainer/devcontainer.json +0 -0
  21. {hydraflow-0.2.14 → hydraflow-0.2.16}/.devcontainer/postCreate.sh +0 -0
  22. {hydraflow-0.2.14 → hydraflow-0.2.16}/.devcontainer/starship.toml +0 -0
  23. {hydraflow-0.2.14 → hydraflow-0.2.16}/.gitattributes +0 -0
  24. {hydraflow-0.2.14 → hydraflow-0.2.16}/LICENSE +0 -0
  25. {hydraflow-0.2.14 → hydraflow-0.2.16}/README.md +0 -0
  26. {hydraflow-0.2.14 → hydraflow-0.2.16}/mkdocs.yml +0 -0
  27. {hydraflow-0.2.14 → hydraflow-0.2.16}/src/hydraflow/__init__.py +0 -0
  28. /hydraflow-0.2.14/tests/scripts/__init__.py → /hydraflow-0.2.16/src/hydraflow/py.typed +0 -0
  29. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/scripts/app.py +0 -0
  30. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/scripts/progress.py +0 -0
  31. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/scripts/watch.py +0 -0
  32. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_context.py +0 -0
  33. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_info.py +0 -0
  34. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_mlflow.py +0 -0
  35. {hydraflow-0.2.14 → hydraflow-0.2.16}/tests/test_version.py +0 -0
@@ -3,4 +3,5 @@
3
3
  .venv/
4
4
  __pycache__/
5
5
  dist/
6
- lcov.info
6
+ lcov.info
7
+ uv.lock
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.2.14
3
+ Version: 0.2.16
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -14,19 +14,12 @@ Classifier: Programming Language :: Python :: 3.10
14
14
  Classifier: Programming Language :: Python :: 3.11
15
15
  Classifier: Programming Language :: Python :: 3.12
16
16
  Requires-Python: >=3.10
17
- Requires-Dist: hydra-core>1.3
17
+ Requires-Dist: hydra-core>=1.3
18
18
  Requires-Dist: joblib
19
- Requires-Dist: mlflow>2.15
19
+ Requires-Dist: mlflow>=2.15
20
20
  Requires-Dist: rich
21
- Requires-Dist: setuptools
22
21
  Requires-Dist: watchdog
23
22
  Requires-Dist: watchfiles
24
- Provides-Extra: dev
25
- Requires-Dist: pytest-asyncio; extra == 'dev'
26
- Requires-Dist: pytest-clarity; extra == 'dev'
27
- Requires-Dist: pytest-cov; extra == 'dev'
28
- Requires-Dist: pytest-randomly; extra == 'dev'
29
- Requires-Dist: pytest-xdist; extra == 'dev'
30
23
  Description-Content-Type: text/markdown
31
24
 
32
25
  # Hydraflow
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.2.14"
7
+ version = "0.2.16"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -18,29 +18,29 @@ classifiers = [
18
18
  ]
19
19
  requires-python = ">=3.10"
20
20
  dependencies = [
21
- "hydra-core>1.3",
21
+ "hydra-core>=1.3",
22
22
  "joblib",
23
- "mlflow>2.15",
23
+ "mlflow>=2.15",
24
24
  "rich",
25
- "setuptools",
26
25
  "watchdog",
27
26
  "watchfiles",
28
27
  ]
29
28
 
30
- [project.optional-dependencies]
31
- dev = [
29
+ [project.urls]
30
+ Documentation = "https://github.com/daizutabi/hydraflow"
31
+ Source = "https://github.com/daizutabi/hydraflow"
32
+ Issues = "https://github.com/daizutabi/hydraflow/issues"
33
+
34
+ [tool.uv]
35
+ dev-dependencies = [
32
36
  "pytest-asyncio",
33
37
  "pytest-clarity",
34
38
  "pytest-cov",
35
39
  "pytest-randomly",
36
40
  "pytest-xdist",
41
+ "ruff",
37
42
  ]
38
43
 
39
- [project.urls]
40
- Documentation = "https://github.com/daizutabi/hydraflow"
41
- Source = "https://github.com/daizutabi/hydraflow"
42
- Issues = "https://github.com/daizutabi/hydraflow/issues"
43
-
44
44
  [tool.hatch.build.targets.sdist]
45
45
  exclude = ["/.github", "/docs"]
46
46
 
@@ -62,7 +62,21 @@ exclude_lines = ["no cov", "raise NotImplementedError", "if TYPE_CHECKING:"]
62
62
 
63
63
  [tool.ruff]
64
64
  line-length = 88
65
- target-version = "py312"
65
+ target-version = "py310"
66
66
 
67
67
  [tool.ruff.lint]
68
- unfixable = ["F401", "RUF100"]
68
+ select = ["ALL"]
69
+ ignore = [
70
+ "ANN003",
71
+ "ANN401",
72
+ "ARG002",
73
+ "B904",
74
+ "D",
75
+ "EM101",
76
+ "PGH003",
77
+ "TRY003",
78
+ ]
79
+ exclude = ["tests/scripts/*.py"]
80
+
81
+ [tool.ruff.lint.per-file-ignores]
82
+ "tests/*" = ["A001", "ANN", "ARG", "FBT", "PLR", "PT", "S", "SIM117", "SLF"]
@@ -42,7 +42,10 @@ async def execute_command(
42
42
  """
43
43
  try:
44
44
  process = await asyncio.create_subprocess_exec(
45
- program, *args, stdout=PIPE, stderr=PIPE
45
+ program,
46
+ *args,
47
+ stdout=PIPE,
48
+ stderr=PIPE,
46
49
  )
47
50
  await asyncio.gather(
48
51
  process_stream(process.stdout, stdout),
@@ -51,7 +54,8 @@ async def execute_command(
51
54
  returncode = await process.wait()
52
55
 
53
56
  except Exception as e:
54
- logger.error(f"Error running command: {e}")
57
+ msg = f"Error running command: {e}"
58
+ logger.exception(msg)
55
59
  returncode = 1
56
60
 
57
61
  finally:
@@ -103,11 +107,15 @@ async def monitor_file_changes(
103
107
  str_paths = [str(path) for path in paths]
104
108
  try:
105
109
  async for changes in watchfiles.awatch(
106
- *str_paths, stop_event=stop_event, **awatch_kwargs
110
+ *str_paths,
111
+ stop_event=stop_event,
112
+ **awatch_kwargs,
107
113
  ):
108
114
  callback(changes)
109
115
  except Exception as e:
110
- logger.error(f"Error watching files: {e}")
116
+ msg = f"Error watching files: {e}"
117
+ logger.exception(msg)
118
+ raise
111
119
 
112
120
 
113
121
  async def run_and_monitor(
@@ -134,13 +142,16 @@ async def run_and_monitor(
134
142
  stop_event = asyncio.Event()
135
143
  run_task = asyncio.create_task(
136
144
  execute_command(
137
- program, *args, stop_event=stop_event, stdout=stdout, stderr=stderr
138
- )
145
+ program,
146
+ *args,
147
+ stop_event=stop_event,
148
+ stdout=stdout,
149
+ stderr=stderr,
150
+ ),
139
151
  )
140
152
  if watch and paths:
141
- monitor_task = asyncio.create_task(
142
- monitor_file_changes(paths, watch, stop_event, **awatch_kwargs)
143
- )
153
+ coro = monitor_file_changes(paths, watch, stop_event, **awatch_kwargs)
154
+ monitor_task = asyncio.create_task(coro)
144
155
  else:
145
156
  monitor_task = None
146
157
 
@@ -151,7 +162,10 @@ async def run_and_monitor(
151
162
  await run_task
152
163
 
153
164
  except Exception as e:
154
- logger.error(f"Error in run_and_monitor: {e}")
165
+ msg = f"Error in run_and_monitor: {e}"
166
+ logger.exception(msg)
167
+ raise
168
+
155
169
  finally:
156
170
  stop_event.set()
157
171
  await run_task
@@ -173,18 +187,24 @@ def run(
173
187
  """
174
188
  Run a command synchronously and optionally watch for file changes.
175
189
 
176
- This function is a synchronous wrapper around the asynchronous `run_and_monitor` function.
177
- It runs a specified command and optionally monitors specified paths for file changes,
178
- invoking the provided callbacks for standard output, standard error, and file changes.
190
+ This function is a synchronous wrapper around the asynchronous
191
+ `run_and_monitor` function. It runs a specified command and optionally
192
+ monitors specified paths for file changes, invoking the provided callbacks for
193
+ standard output, standard error, and file changes.
179
194
 
180
195
  Args:
181
196
  program (str): The program to run.
182
197
  *args (str): Arguments for the program.
183
- stdout (Callable[[str], None] | None): Callback for handling standard output lines.
184
- stderr (Callable[[str], None] | None): Callback for handling standard error lines.
185
- watch (Callable[[set[tuple[Change, str]]], None] | None): Callback for handling file changes.
186
- paths (list[str | Path] | None): List of paths to monitor for file changes.
187
- **awatch_kwargs: Additional keyword arguments to pass to `watchfiles.awatch`.
198
+ stdout (Callable[[str], None] | None): Callback for handling standard
199
+ output lines.
200
+ stderr (Callable[[str], None] | None): Callback for handling standard
201
+ error lines.
202
+ watch (Callable[[set[tuple[Change, str]]], None] | None): Callback for
203
+ handling file changes.
204
+ paths (list[str | Path] | None): List of paths to monitor for file
205
+ changes.
206
+ **awatch_kwargs: Additional keyword arguments to pass to
207
+ `watchfiles.awatch`.
188
208
 
189
209
  Returns:
190
210
  int: The return code of the process.
@@ -201,5 +221,5 @@ def run(
201
221
  watch=watch,
202
222
  paths=paths,
203
223
  **awatch_kwargs,
204
- )
224
+ ),
205
225
  )
@@ -33,7 +33,7 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
33
33
  if config is None:
34
34
  return
35
35
 
36
- if not isinstance(config, (DictConfig, ListConfig)):
36
+ if not isinstance(config, DictConfig | ListConfig):
37
37
  config = OmegaConf.create(config) # type: ignore
38
38
 
39
39
  yield from _iter_params(config, prefix)
@@ -62,8 +62,8 @@ def _is_param(value: object) -> bool:
62
62
  if isinstance(value, DictConfig):
63
63
  return False
64
64
 
65
- if isinstance(value, ListConfig):
66
- if any(isinstance(v, (DictConfig, ListConfig)) for v in value):
65
+ if isinstance(value, ListConfig): # noqa: SIM102
66
+ if any(isinstance(v, DictConfig | ListConfig) for v in value):
67
67
  return False
68
68
 
69
69
  return True
@@ -75,7 +75,8 @@ def log_run(
75
75
  yield
76
76
 
77
77
  except Exception as e:
78
- log.error(f"Error during log_run: {e}")
78
+ msg = f"Error during log_run: {e}"
79
+ log.exception(msg)
79
80
  raise
80
81
 
81
82
  finally:
@@ -84,7 +85,7 @@ def log_run(
84
85
 
85
86
 
86
87
  @contextmanager
87
- def start_run(
88
+ def start_run( # noqa: PLR0913
88
89
  config: object,
89
90
  *,
90
91
  run_id: str | None = None,
@@ -112,8 +113,10 @@ def start_run(
112
113
  parent_run_id (str | None): The parent run ID. Defaults to None.
113
114
  tags (dict[str, str] | None): Tags to associate with the run. Defaults to None.
114
115
  description (str | None): A description of the run. Defaults to None.
115
- log_system_metrics (bool | None): Whether to log system metrics. Defaults to None.
116
- synchronous (bool | None): Whether to log parameters synchronously. Defaults to None.
116
+ log_system_metrics (bool | None): Whether to log system metrics.
117
+ Defaults to None.
118
+ synchronous (bool | None): Whether to log parameters synchronously.
119
+ Defaults to None.
117
120
 
118
121
  Yields:
119
122
  Run: An MLflow Run object representing the started run.
@@ -128,24 +131,27 @@ def start_run(
128
131
  - `log_run`: A context manager to log parameters and manage the MLflow
129
132
  run context.
130
133
  """
131
- with mlflow.start_run(
132
- run_id=run_id,
133
- experiment_id=experiment_id,
134
- run_name=run_name,
135
- nested=nested,
136
- parent_run_id=parent_run_id,
137
- tags=tags,
138
- description=description,
139
- log_system_metrics=log_system_metrics,
140
- ) as run:
141
- with log_run(config, synchronous=synchronous):
142
- yield run
134
+ with (
135
+ mlflow.start_run(
136
+ run_id=run_id,
137
+ experiment_id=experiment_id,
138
+ run_name=run_name,
139
+ nested=nested,
140
+ parent_run_id=parent_run_id,
141
+ tags=tags,
142
+ description=description,
143
+ log_system_metrics=log_system_metrics,
144
+ ) as run,
145
+ log_run(config, synchronous=synchronous),
146
+ ):
147
+ yield run
143
148
 
144
149
 
145
150
  @contextmanager
146
151
  def watch(
147
152
  callback: Callable[[Path], None],
148
- dir: Path | str = "",
153
+ dir: Path | str = "", # noqa: A002
154
+ *,
149
155
  timeout: int = 60,
150
156
  ignore_patterns: list[str] | None = None,
151
157
  ignore_log: bool = True,
@@ -178,9 +184,9 @@ def watch(
178
184
  pass
179
185
  ```
180
186
  """
181
- dir = dir or get_artifact_dir()
187
+ dir = dir or get_artifact_dir() # noqa: A001
182
188
  if isinstance(dir, Path):
183
- dir = dir.as_posix()
189
+ dir = dir.as_posix() # noqa: A001
184
190
 
185
191
  handler = Handler(callback, ignore_patterns=ignore_patterns, ignore_log=ignore_log)
186
192
  observer = Observer()
@@ -191,7 +197,8 @@ def watch(
191
197
  yield
192
198
 
193
199
  except Exception as e:
194
- log.error(f"Error during watch: {e}")
200
+ msg = f"Error during watch: {e}"
201
+ log.exception(msg)
195
202
  raise
196
203
 
197
204
  finally:
@@ -210,6 +217,7 @@ class Handler(PatternMatchingEventHandler):
210
217
  def __init__(
211
218
  self,
212
219
  func: Callable[[Path], None],
220
+ *,
213
221
  ignore_patterns: list[str] | None = None,
214
222
  ignore_log: bool = True,
215
223
  ) -> None:
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
15
15
 
16
16
 
17
17
  class RunCollectionInfo:
18
- def __init__(self, runs: RunCollection):
18
+ def __init__(self, runs: RunCollection) -> None:
19
19
  self._runs = runs
20
20
 
21
21
  @property
@@ -81,7 +81,8 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
81
81
  mlflow.log_param(key, value, synchronous=synchronous)
82
82
 
83
83
 
84
- def search_runs(
84
+ def search_runs( # noqa: PLR0913
85
+ *,
85
86
  experiment_ids: list[str] | None = None,
86
87
  filter_string: str = "",
87
88
  run_view_type: int = ViewType.ACTIVE_ONLY,
@@ -148,7 +149,8 @@ def search_runs(
148
149
 
149
150
 
150
151
  def list_runs(
151
- experiment_names: str | list[str] | None = None, n_jobs: int = 0
152
+ experiment_names: str | list[str] | None = None,
153
+ n_jobs: int = 0,
152
154
  ) -> RunCollection:
153
155
  """
154
156
  List all runs for the specified experiments.
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
31
31
 
32
32
  # https://github.com/jonghwanhyeon/joblib-progress/blob/main/joblib_progress/__init__.py
33
33
  @contextmanager
34
- def JoblibProgress(
34
+ def JoblibProgress( # noqa: N802
35
35
  *columns: ProgressColumn | str,
36
36
  description: str | None = None,
37
37
  total: int | None = None,
@@ -68,7 +68,7 @@ def JoblibProgress(
68
68
  task_id = progress.add_task(description, total=total)
69
69
  print_progress = joblib.parallel.Parallel.print_progress
70
70
 
71
- def update_progress(self: joblib.parallel.Parallel):
71
+ def update_progress(self: joblib.parallel.Parallel) -> None:
72
72
  progress.update(task_id, completed=self.n_completed_tasks, refresh=True)
73
73
  return print_progress(self)
74
74
 
@@ -23,8 +23,6 @@ from dataclasses import dataclass, field
23
23
  from itertools import chain
24
24
  from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
25
25
 
26
- from mlflow.entities.run import Run
27
-
28
26
  from hydraflow.config import iter_params
29
27
  from hydraflow.info import RunCollectionInfo
30
28
 
@@ -33,6 +31,7 @@ if TYPE_CHECKING:
33
31
  from pathlib import Path
34
32
  from typing import Any
35
33
 
34
+ from mlflow.entities.run import Run
36
35
  from omegaconf import DictConfig
37
36
 
38
37
 
@@ -60,7 +59,7 @@ class RunCollection:
60
59
  _info: RunCollectionInfo = field(init=False)
61
60
  """An instance of `RunCollectionInfo`."""
62
61
 
63
- def __post_init__(self):
62
+ def __post_init__(self) -> None:
64
63
  self._info = RunCollectionInfo(self)
65
64
 
66
65
  def __repr__(self) -> str:
@@ -89,7 +88,7 @@ class RunCollection:
89
88
 
90
89
  @classmethod
91
90
  def from_list(cls, runs: list[Run]) -> RunCollection:
92
- """Create a new `RunCollection` instance from a list of MLflow `Run` instances."""
91
+ """Create a `RunCollection` instance from a list of MLflow `Run` instances."""
93
92
 
94
93
  return cls(runs)
95
94
 
@@ -120,6 +119,7 @@ class RunCollection:
120
119
  def sort(
121
120
  self,
122
121
  key: Callable[[Run], Any] | None = None,
122
+ *,
123
123
  reverse: bool = False,
124
124
  ) -> None:
125
125
  self._runs.sort(key=key or (lambda x: x.info.start_time), reverse=reverse)
@@ -393,7 +393,7 @@ class RunCollection:
393
393
  param_names = set()
394
394
 
395
395
  for run in self:
396
- for param in run.data.params.keys():
396
+ for param in run.data.params:
397
397
  param_names.add(param)
398
398
 
399
399
  return list(param_names)
@@ -537,10 +537,11 @@ class RunCollection:
537
537
  Results obtained by applying the function to each artifact directory
538
538
  in the collection.
539
539
  """
540
- return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir)
540
+ return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir) # noqa: A001
541
541
 
542
542
  def group_by(
543
- self, *names: str | list[str]
543
+ self,
544
+ *names: str | list[str],
544
545
  ) -> dict[tuple[str | None, ...], RunCollection]:
545
546
  """
546
547
  Group runs by specified parameter names.
@@ -595,13 +596,19 @@ def _param_matches(run: Run, key: str, value: Any) -> bool:
595
596
  if isinstance(value, list) and value:
596
597
  return type(value[0])(param) in value
597
598
 
598
- if isinstance(value, tuple) and len(value) == 2:
599
+ if isinstance(value, tuple) and len(value) == 2: # noqa: PLR2004
599
600
  return value[0] <= type(value[0])(param) < value[1]
600
601
 
601
602
  return type(value)(param) == value
602
603
 
603
604
 
604
- def filter_runs(runs: list[Run], config: object | None = None, **kwargs) -> list[Run]:
605
+ def filter_runs(
606
+ runs: list[Run],
607
+ config: object | None = None,
608
+ *,
609
+ status: str | list[str] | None = None,
610
+ **kwargs,
611
+ ) -> list[Run]:
605
612
  """
606
613
  Filter the runs based on the provided configuration.
607
614
 
@@ -623,6 +630,7 @@ def filter_runs(runs: list[Run], config: object | None = None, **kwargs) -> list
623
630
  config (object | None): The configuration object to filter the runs.
624
631
  This can be any object that provides key-value pairs through the
625
632
  `iter_params` function.
633
+ status (str | list[str] | None): The status of the runs to filter.
626
634
  **kwargs: Additional key-value pairs to filter the runs.
627
635
 
628
636
  Returns:
@@ -634,6 +642,15 @@ def filter_runs(runs: list[Run], config: object | None = None, **kwargs) -> list
634
642
  if len(runs) == 0:
635
643
  return []
636
644
 
645
+ if isinstance(status, str) and status.startswith("!"):
646
+ status = status[1:].lower()
647
+ return [run for run in runs if run.info.status.lower() != status]
648
+
649
+ if status:
650
+ status = [status] if isinstance(status, str) else status
651
+ status = [s.lower() for s in status]
652
+ return [run for run in runs if run.info.status.lower() in status]
653
+
637
654
  return runs
638
655
 
639
656
 
File without changes
File without changes
@@ -3,12 +3,15 @@ from __future__ import annotations
3
3
  import subprocess
4
4
  import sys
5
5
  from pathlib import Path
6
+ from typing import TYPE_CHECKING
6
7
 
7
8
  import mlflow
8
9
  import pytest
9
- from omegaconf import DictConfig
10
10
 
11
- from hydraflow.run_collection import RunCollection
11
+ if TYPE_CHECKING:
12
+ from omegaconf import DictConfig
13
+
14
+ from hydraflow.run_collection import RunCollection
12
15
 
13
16
 
14
17
  @pytest.fixture
@@ -32,7 +35,7 @@ def test_list_runs_all(rc: RunCollection):
32
35
  rc_ = list_runs([])
33
36
  assert len(rc) == len(rc_)
34
37
 
35
- for a, b in zip(rc, rc_):
38
+ for a, b in zip(rc, rc_, strict=False):
36
39
  assert a.info.run_id == b.info.run_id
37
40
  assert a.info.start_time == b.info.start_time
38
41
  assert a.info.status == b.info.status
@@ -46,7 +49,7 @@ def test_list_runs_parallel(rc: RunCollection, n_jobs: int):
46
49
  rc_ = list_runs("_info_", n_jobs=n_jobs)
47
50
  assert len(rc) == len(rc_)
48
51
 
49
- for a, b in zip(rc, rc_):
52
+ for a, b in zip(rc, rc_, strict=False):
50
53
  assert a.info.run_id == b.info.run_id
51
54
  assert a.info.start_time == b.info.start_time
52
55
  assert a.info.status == b.info.status
@@ -61,7 +64,7 @@ def test_list_runs_parallel_active(rc: RunCollection, n_jobs: int):
61
64
  rc_ = list_runs(n_jobs=n_jobs)
62
65
  assert len(rc) == len(rc_)
63
66
 
64
- for a, b in zip(rc, rc_):
67
+ for a, b in zip(rc, rc_, strict=False):
65
68
  assert a.info.run_id == b.info.run_id
66
69
  assert a.info.start_time == b.info.start_time
67
70
  assert a.info.status == b.info.status
@@ -98,7 +101,6 @@ def test_app_info_config(rc: RunCollection):
98
101
 
99
102
  def test_app_info_artifact_uri(rc: RunCollection):
100
103
  uris = rc.info.artifact_uri
101
- print(uris)
102
104
  assert all(uri.startswith("file://") for uri in uris) # type: ignore
103
105
  assert all(uri.endswith("/artifacts") for uri in uris) # type: ignore
104
106
  assert all("mlruns" in uri for uri in uris) # type: ignore
@@ -5,10 +5,14 @@ import sys
5
5
  from pathlib import Path
6
6
  from threading import Thread
7
7
  from time import sleep
8
- from typing import Callable
8
+ from typing import TYPE_CHECKING
9
9
 
10
10
  import pytest
11
- from watchfiles import Change
11
+
12
+ if TYPE_CHECKING:
13
+ from collections.abc import Callable
14
+
15
+ from watchfiles import Change
12
16
 
13
17
 
14
18
  def sleep_write(path: Path):
@@ -72,7 +76,9 @@ async def test_monitor_file_changes(tmp_path: Path, write_soon: Callable[[Path],
72
76
  changes_detected.extend(changes)
73
77
 
74
78
  write_soon(tmp_path / "test.txt")
75
- monitor_task = asyncio.create_task(monitor_file_changes([tmp_path], callback, stop_event))
79
+ monitor_task = asyncio.create_task(
80
+ monitor_file_changes([tmp_path], callback, stop_event),
81
+ )
76
82
 
77
83
  await asyncio.sleep(1)
78
84
  stop_event.set()
@@ -135,7 +135,11 @@ def test_iter_params_from_config(cfg):
135
135
  def test_iter_params_with_empty_config():
136
136
  from hydraflow.config import iter_params
137
137
 
138
- empty_cfg = Config(size=Size(x=0, y=0), db=Db(name="", port=0), store=Store(items=[]))
138
+ empty_cfg = Config(
139
+ size=Size(x=0, y=0),
140
+ db=Db(name="", port=0),
141
+ store=Store(items=[]),
142
+ )
139
143
  it = iter_params(empty_cfg)
140
144
  assert next(it) == ("size.x", 0)
141
145
  assert next(it) == ("size.y", 0)
@@ -46,8 +46,7 @@ def test_output(run_id: str):
46
46
 
47
47
  def read_log(run_id: str, path: str) -> str:
48
48
  path = download_artifacts(run_id=run_id, artifact_path=path)
49
- text = Path(path).read_text()
50
- return text
49
+ return Path(path).read_text()
51
50
 
52
51
 
53
52
  def test_load_config(run: Run):
@@ -5,8 +5,8 @@ import pytest
5
5
 
6
6
 
7
7
  @pytest.mark.skipif(
8
- sys.platform == "win32", reason="'cp932' codec can't encode character '\\u2807'"
8
+ sys.platform == "win32", reason="'cp932' codec can't encode character '\\u2807'",
9
9
  )
10
10
  def test_progress_bar():
11
- cp = run([sys.executable, "tests/scripts/progress.py"])
11
+ cp = run([sys.executable, "tests/scripts/progress.py"], check=False)
12
12
  assert cp.returncode == 0
@@ -90,6 +90,16 @@ def test_filter_invalid_param(run_list: list[Run]):
90
90
  assert len(x) == 6
91
91
 
92
92
 
93
+ def test_filter_status(run_list: list[Run]):
94
+ from hydraflow.run_collection import filter_runs
95
+
96
+ assert not filter_runs(run_list, status="RUNNING")
97
+ assert filter_runs(run_list, status="finished") == run_list
98
+ assert filter_runs(run_list, status=["finished", "running"]) == run_list
99
+ assert filter_runs(run_list, status="!RUNNING") == run_list
100
+ assert not filter_runs(run_list, status="!finished")
101
+
102
+
93
103
  def test_get_params(run_list: list[Run]):
94
104
  from hydraflow.run_collection import get_params
95
105
 
@@ -162,8 +172,6 @@ def test_runs_filter(runs: RunCollection):
162
172
 
163
173
 
164
174
  def test_runs_get(runs: RunCollection):
165
- from hydraflow.run_collection import Run
166
-
167
175
  run = runs.get({"p": 4})
168
176
  assert isinstance(run, Run)
169
177
  run = runs.get(p=2)
@@ -195,8 +203,6 @@ def test_runs_get_params_dict(runs: RunCollection):
195
203
 
196
204
 
197
205
  def test_runs_find(runs: RunCollection):
198
- from hydraflow.run_collection import Run
199
-
200
206
  run = runs.find({"r": 0})
201
207
  assert isinstance(run, Run)
202
208
  assert run.data.params["p"] == "0"
@@ -216,8 +222,6 @@ def test_runs_try_find_none(runs: RunCollection):
216
222
 
217
223
 
218
224
  def test_runs_find_last(runs: RunCollection):
219
- from hydraflow.run_collection import Run
220
-
221
225
  run = runs.find_last({"r": 0})
222
226
  assert isinstance(run, Run)
223
227
  assert run.data.params["p"] == "3"
@@ -303,7 +307,7 @@ def test_run_collection_map_run_id_kwargs(runs: RunCollection):
303
307
  def test_run_collection_map_uri(runs: RunCollection):
304
308
  results = list(runs.map_uri(lambda uri: uri))
305
309
  assert len(results) == len(runs._runs)
306
- assert all(isinstance(uri, (str, type(None))) for uri in results)
310
+ assert all(isinstance(uri, str | type(None)) for uri in results)
307
311
 
308
312
 
309
313
  def test_run_collection_map_dir(runs: RunCollection):
@@ -8,7 +8,7 @@ import pytest
8
8
 
9
9
 
10
10
  @pytest.mark.parametrize("dir", [".", Path])
11
- def test_watch(dir, monkeypatch, tmp_path):
11
+ def test_watch(dir, monkeypatch, tmp_path): # noqa: A002
12
12
  from hydraflow.context import watch
13
13
 
14
14
  file = Path("tests/scripts/watch.py").absolute()
File without changes
File without changes
File without changes
File without changes