hydraflow 0.2.15__tar.gz → 0.2.17__tar.gz

Files changed (37)
  1. {hydraflow-0.2.15 → hydraflow-0.2.17}/.gitignore +2 -1
  2. {hydraflow-0.2.15 → hydraflow-0.2.17}/PKG-INFO +3 -10
  3. {hydraflow-0.2.15 → hydraflow-0.2.17}/pyproject.toml +28 -13
  4. {hydraflow-0.2.15 → hydraflow-0.2.17}/src/hydraflow/asyncio.py +39 -19
  5. {hydraflow-0.2.15 → hydraflow-0.2.17}/src/hydraflow/config.py +3 -3
  6. {hydraflow-0.2.15 → hydraflow-0.2.17}/src/hydraflow/context.py +28 -20
  7. {hydraflow-0.2.15 → hydraflow-0.2.17}/src/hydraflow/info.py +1 -1
  8. {hydraflow-0.2.15 → hydraflow-0.2.17}/src/hydraflow/mlflow.py +4 -2
  9. hydraflow-0.2.17/src/hydraflow/param.py +64 -0
  10. {hydraflow-0.2.15 → hydraflow-0.2.17}/src/hydraflow/progress.py +2 -2
  11. {hydraflow-0.2.15 → hydraflow-0.2.17}/src/hydraflow/run_collection.py +18 -35
  12. hydraflow-0.2.17/tests/__init__.py +0 -0
  13. hydraflow-0.2.17/tests/scripts/__init__.py +0 -0
  14. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/scripts/app.py +2 -1
  15. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/test_app.py +30 -14
  16. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/test_asyncio.py +9 -3
  17. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/test_config.py +27 -2
  18. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/test_log_run.py +1 -2
  19. hydraflow-0.2.17/tests/test_param.py +78 -0
  20. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/test_progress.py +2 -2
  21. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/test_run_collection.py +131 -124
  22. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/test_watch.py +1 -1
  23. {hydraflow-0.2.15 → hydraflow-0.2.17}/.devcontainer/devcontainer.json +0 -0
  24. {hydraflow-0.2.15 → hydraflow-0.2.17}/.devcontainer/postCreate.sh +0 -0
  25. {hydraflow-0.2.15 → hydraflow-0.2.17}/.devcontainer/starship.toml +0 -0
  26. {hydraflow-0.2.15 → hydraflow-0.2.17}/.gitattributes +0 -0
  27. {hydraflow-0.2.15 → hydraflow-0.2.17}/LICENSE +0 -0
  28. {hydraflow-0.2.15 → hydraflow-0.2.17}/README.md +0 -0
  29. {hydraflow-0.2.15 → hydraflow-0.2.17}/mkdocs.yml +0 -0
  30. {hydraflow-0.2.15 → hydraflow-0.2.17}/src/hydraflow/__init__.py +0 -0
  31. /hydraflow-0.2.15/tests/scripts/__init__.py → /hydraflow-0.2.17/src/hydraflow/py.typed +0 -0
  32. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/scripts/progress.py +0 -0
  33. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/scripts/watch.py +0 -0
  34. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/test_context.py +0 -0
  35. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/test_info.py +0 -0
  36. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/test_mlflow.py +0 -0
  37. {hydraflow-0.2.15 → hydraflow-0.2.17}/tests/test_version.py +0 -0

.gitignore
@@ -3,4 +3,5 @@
  .venv/
  __pycache__/
  dist/
- lcov.info
+ lcov.info
+ uv.lock

PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: hydraflow
- Version: 0.2.15
+ Version: 0.2.17
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -14,19 +14,12 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Requires-Python: >=3.10
- Requires-Dist: hydra-core>1.3
+ Requires-Dist: hydra-core>=1.3
  Requires-Dist: joblib
- Requires-Dist: mlflow>2.15
+ Requires-Dist: mlflow>=2.15
  Requires-Dist: rich
- Requires-Dist: setuptools
  Requires-Dist: watchdog
  Requires-Dist: watchfiles
- Provides-Extra: dev
- Requires-Dist: pytest-asyncio; extra == 'dev'
- Requires-Dist: pytest-clarity; extra == 'dev'
- Requires-Dist: pytest-cov; extra == 'dev'
- Requires-Dist: pytest-randomly; extra == 'dev'
- Requires-Dist: pytest-xdist; extra == 'dev'
  Description-Content-Type: text/markdown

  # Hydraflow

pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

  [project]
  name = "hydraflow"
- version = "0.2.15"
+ version = "0.2.17"
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
  readme = "README.md"
  license = "MIT"
@@ -18,29 +18,29 @@ classifiers = [
  ]
  requires-python = ">=3.10"
  dependencies = [
- "hydra-core>1.3",
+ "hydra-core>=1.3",
  "joblib",
- "mlflow>2.15",
+ "mlflow>=2.15",
  "rich",
- "setuptools",
  "watchdog",
  "watchfiles",
  ]

- [project.optional-dependencies]
- dev = [
+ [project.urls]
+ Documentation = "https://github.com/daizutabi/hydraflow"
+ Source = "https://github.com/daizutabi/hydraflow"
+ Issues = "https://github.com/daizutabi/hydraflow/issues"
+
+ [tool.uv]
+ dev-dependencies = [
  "pytest-asyncio",
  "pytest-clarity",
  "pytest-cov",
  "pytest-randomly",
  "pytest-xdist",
+ "ruff",
  ]

- [project.urls]
- Documentation = "https://github.com/daizutabi/hydraflow"
- Source = "https://github.com/daizutabi/hydraflow"
- Issues = "https://github.com/daizutabi/hydraflow/issues"
-
  [tool.hatch.build.targets.sdist]
  exclude = ["/.github", "/docs"]

@@ -62,7 +62,22 @@ exclude_lines = ["no cov", "raise NotImplementedError", "if TYPE_CHECKING:"]

  [tool.ruff]
  line-length = 88
- target-version = "py312"
+ target-version = "py310"

  [tool.ruff.lint]
- unfixable = ["F401", "RUF100"]
+ select = ["ALL"]
+ unfixable = ["F401"]
+ ignore = [
+ "ANN003",
+ "ANN401",
+ "ARG002",
+ "B904",
+ "D",
+ "EM101",
+ "PGH003",
+ "TRY003",
+ ]
+ exclude = ["tests/scripts/*.py"]
+
+ [tool.ruff.lint.per-file-ignores]
+ "tests/*" = ["A001", "ANN", "ARG", "FBT", "PLR", "PT", "S", "SIM117", "SLF"]

src/hydraflow/asyncio.py
@@ -42,7 +42,10 @@ async def execute_command(
  """
  try:
  process = await asyncio.create_subprocess_exec(
- program, *args, stdout=PIPE, stderr=PIPE
+ program,
+ *args,
+ stdout=PIPE,
+ stderr=PIPE,
  )
  await asyncio.gather(
  process_stream(process.stdout, stdout),
@@ -51,7 +54,8 @@ async def execute_command(
  returncode = await process.wait()

  except Exception as e:
- logger.error(f"Error running command: {e}")
+ msg = f"Error running command: {e}"
+ logger.exception(msg)
  returncode = 1

  finally:
@@ -103,11 +107,15 @@ async def monitor_file_changes(
  str_paths = [str(path) for path in paths]
  try:
  async for changes in watchfiles.awatch(
- *str_paths, stop_event=stop_event, **awatch_kwargs
+ *str_paths,
+ stop_event=stop_event,
+ **awatch_kwargs,
  ):
  callback(changes)
  except Exception as e:
- logger.error(f"Error watching files: {e}")
+ msg = f"Error watching files: {e}"
+ logger.exception(msg)
+ raise


  async def run_and_monitor(
@@ -134,13 +142,16 @@ async def run_and_monitor(
  stop_event = asyncio.Event()
  run_task = asyncio.create_task(
  execute_command(
- program, *args, stop_event=stop_event, stdout=stdout, stderr=stderr
- )
+ program,
+ *args,
+ stop_event=stop_event,
+ stdout=stdout,
+ stderr=stderr,
+ ),
  )
  if watch and paths:
- monitor_task = asyncio.create_task(
- monitor_file_changes(paths, watch, stop_event, **awatch_kwargs)
- )
+ coro = monitor_file_changes(paths, watch, stop_event, **awatch_kwargs)
+ monitor_task = asyncio.create_task(coro)
  else:
  monitor_task = None

@@ -151,7 +162,10 @@ async def run_and_monitor(
  await run_task

  except Exception as e:
- logger.error(f"Error in run_and_monitor: {e}")
+ msg = f"Error in run_and_monitor: {e}"
+ logger.exception(msg)
+ raise
+
  finally:
  stop_event.set()
  await run_task
@@ -173,18 +187,24 @@ def run(
  """
  Run a command synchronously and optionally watch for file changes.

- This function is a synchronous wrapper around the asynchronous `run_and_monitor` function.
- It runs a specified command and optionally monitors specified paths for file changes,
- invoking the provided callbacks for standard output, standard error, and file changes.
+ This function is a synchronous wrapper around the asynchronous
+ `run_and_monitor` function. It runs a specified command and optionally
+ monitors specified paths for file changes, invoking the provided callbacks for
+ standard output, standard error, and file changes.

  Args:
  program (str): The program to run.
  *args (str): Arguments for the program.
- stdout (Callable[[str], None] | None): Callback for handling standard output lines.
- stderr (Callable[[str], None] | None): Callback for handling standard error lines.
- watch (Callable[[set[tuple[Change, str]]], None] | None): Callback for handling file changes.
- paths (list[str | Path] | None): List of paths to monitor for file changes.
- **awatch_kwargs: Additional keyword arguments to pass to `watchfiles.awatch`.
+ stdout (Callable[[str], None] | None): Callback for handling standard
+ output lines.
+ stderr (Callable[[str], None] | None): Callback for handling standard
+ error lines.
+ watch (Callable[[set[tuple[Change, str]]], None] | None): Callback for
+ handling file changes.
+ paths (list[str | Path] | None): List of paths to monitor for file
+ changes.
+ **awatch_kwargs: Additional keyword arguments to pass to
+ `watchfiles.awatch`.

  Returns:
  int: The return code of the process.
@@ -201,5 +221,5 @@ def run(
  watch=watch,
  paths=paths,
  **awatch_kwargs,
- )
+ ),
  )

src/hydraflow/config.py
@@ -33,7 +33,7 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
  if config is None:
  return

- if not isinstance(config, (DictConfig, ListConfig)):
+ if not isinstance(config, DictConfig | ListConfig):
  config = OmegaConf.create(config) # type: ignore

  yield from _iter_params(config, prefix)
@@ -62,8 +62,8 @@ def _is_param(value: object) -> bool:
  if isinstance(value, DictConfig):
  return False

- if isinstance(value, ListConfig):
- if any(isinstance(v, (DictConfig, ListConfig)) for v in value):
+ if isinstance(value, ListConfig): # noqa: SIM102
+ if any(isinstance(v, DictConfig | ListConfig) for v in value):
  return False

  return True

src/hydraflow/context.py
@@ -75,7 +75,8 @@ def log_run(
  yield

  except Exception as e:
- log.error(f"Error during log_run: {e}")
+ msg = f"Error during log_run: {e}"
+ log.exception(msg)
  raise

  finally:
@@ -84,7 +85,7 @@


  @contextmanager
- def start_run(
+ def start_run( # noqa: PLR0913
  config: object,
  *,
  run_id: str | None = None,
@@ -112,8 +113,10 @@ def start_run(
  parent_run_id (str | None): The parent run ID. Defaults to None.
  tags (dict[str, str] | None): Tags to associate with the run. Defaults to None.
  description (str | None): A description of the run. Defaults to None.
- log_system_metrics (bool | None): Whether to log system metrics. Defaults to None.
- synchronous (bool | None): Whether to log parameters synchronously. Defaults to None.
+ log_system_metrics (bool | None): Whether to log system metrics.
+ Defaults to None.
+ synchronous (bool | None): Whether to log parameters synchronously.
+ Defaults to None.

  Yields:
  Run: An MLflow Run object representing the started run.
@@ -128,24 +131,27 @@
  - `log_run`: A context manager to log parameters and manage the MLflow
  run context.
  """
- with mlflow.start_run(
- run_id=run_id,
- experiment_id=experiment_id,
- run_name=run_name,
- nested=nested,
- parent_run_id=parent_run_id,
- tags=tags,
- description=description,
- log_system_metrics=log_system_metrics,
- ) as run:
- with log_run(config, synchronous=synchronous):
- yield run
+ with (
+ mlflow.start_run(
+ run_id=run_id,
+ experiment_id=experiment_id,
+ run_name=run_name,
+ nested=nested,
+ parent_run_id=parent_run_id,
+ tags=tags,
+ description=description,
+ log_system_metrics=log_system_metrics,
+ ) as run,
+ log_run(config, synchronous=synchronous),
+ ):
+ yield run


  @contextmanager
  def watch(
  callback: Callable[[Path], None],
- dir: Path | str = "",
+ dir: Path | str = "", # noqa: A002
+ *,
  timeout: int = 60,
  ignore_patterns: list[str] | None = None,
  ignore_log: bool = True,
@@ -178,9 +184,9 @@ def watch(
  pass
  ```
  """
- dir = dir or get_artifact_dir()
+ dir = dir or get_artifact_dir() # noqa: A001
  if isinstance(dir, Path):
- dir = dir.as_posix()
+ dir = dir.as_posix() # noqa: A001

  handler = Handler(callback, ignore_patterns=ignore_patterns, ignore_log=ignore_log)
  observer = Observer()
@@ -191,7 +197,8 @@
  yield

  except Exception as e:
- log.error(f"Error during watch: {e}")
+ msg = f"Error during watch: {e}"
+ log.exception(msg)
  raise

  finally:
@@ -210,6 +217,7 @@ class Handler(PatternMatchingEventHandler):
  def __init__(
  self,
  func: Callable[[Path], None],
+ *,
  ignore_patterns: list[str] | None = None,
  ignore_log: bool = True,
  ) -> None:
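
For orientation, here is a minimal usage sketch of the `start_run` context manager changed above. It is illustrative only, not part of the diff: the experiment name, the `cfg` object, and the logged metric are hypothetical placeholders, and `start_run` is imported directly from the module shown in this hunk.

```python
# Illustrative sketch (assumptions noted above), not part of the package diff.
import mlflow

from hydraflow.context import start_run


def train(cfg) -> None:
    mlflow.set_experiment("example")  # hypothetical experiment name
    # start_run opens an MLflow run and, via log_run, logs every leaf of
    # `cfg` as an MLflow parameter before yielding the active Run object.
    with start_run(cfg) as run:
        mlflow.log_metric("accuracy", 0.9)  # hypothetical metric
        print(run.info.run_id)
```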

src/hydraflow/info.py
@@ -15,7 +15,7 @@ if TYPE_CHECKING:


  class RunCollectionInfo:
- def __init__(self, runs: RunCollection):
+ def __init__(self, runs: RunCollection) -> None:
  self._runs = runs

  @property

src/hydraflow/mlflow.py
@@ -81,7 +81,8 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
  mlflow.log_param(key, value, synchronous=synchronous)


- def search_runs(
+ def search_runs( # noqa: PLR0913
+ *,
  experiment_ids: list[str] | None = None,
  filter_string: str = "",
  run_view_type: int = ViewType.ACTIVE_ONLY,
@@ -148,7 +149,8 @@
  

  def list_runs(
- experiment_names: str | list[str] | None = None, n_jobs: int = 0
+ experiment_names: str | list[str] | None = None,
+ n_jobs: int = 0,
  ) -> RunCollection:
  """
  List all runs for the specified experiments.

src/hydraflow/param.py (new file)
@@ -0,0 +1,64 @@
+ from __future__ import annotations
+
+ from typing import Any
+
+
+ def match(param: str, value: Any) -> bool:
+ """Check if the string matches the specified value.
+
+ Args:
+ param (str): The parameter to check.
+ value (Any): The value to check.
+
+ Returns:
+ True if the parameter matches the specified value,
+ False otherwise.
+ """
+ if value in [None, True, False]:
+ return param == str(value)
+
+ if isinstance(value, list) and (m := _match_list(param, value)) is not None:
+ return m
+
+ if isinstance(value, tuple) and (m := _match_tuple(param, value)) is not None:
+ return m
+
+ if isinstance(value, int | float | str):
+ return type(value)(param) == value
+
+ return param == str(value)
+
+
+ def _match_list(param: str, value: list) -> bool | None:
+ if not value:
+ return None
+
+ if any(param.startswith(x) for x in ["[", "(", "{"]):
+ return None
+
+ if isinstance(value[0], bool):
+ return None
+
+ if not isinstance(value[0], int | float | str):
+ return None
+
+ return type(value[0])(param) in value
+
+
+ def _match_tuple(param: str, value: tuple) -> bool | None:
+ if len(value) != 2: # noqa: PLR2004
+ return None
+
+ if any(param.startswith(x) for x in ["[", "(", "{"]):
+ return None
+
+ if isinstance(value[0], bool):
+ return None
+
+ if not isinstance(value[0], int | float | str):
+ return None
+
+ if type(value[0]) is not type(value[1]):
+ return None
+
+ return value[0] <= type(value[0])(param) < value[1] # type: ignore
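
The new `src/hydraflow/param.py` centralizes the value-matching rules used when filtering runs (MLflow stores every parameter as a string). A minimal sketch of how the `match` function shown above behaves; the assertions follow from the code in this hunk and assume hydraflow 0.2.17 is installed.

```python
# Illustrative sketch, not part of the diff: exercising hydraflow.param.match.
from hydraflow.param import match

# Scalars: the stored string is cast to the value's type and compared.
assert match("3306", 3306)
assert not match("3307", 3306)

# None and booleans are compared against their string form.
assert match("True", True)
assert match("None", None)

# A list asks "is the stored value one of these?" (cast to the element type).
assert match("2", [1, 2, 3])

# A two-element tuple is treated as a half-open range [low, high).
assert match("1.5", (1.0, 2.0))
assert not match("2.0", (1.0, 2.0))
```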

src/hydraflow/progress.py
@@ -31,7 +31,7 @@ if TYPE_CHECKING:

  # https://github.com/jonghwanhyeon/joblib-progress/blob/main/joblib_progress/__init__.py
  @contextmanager
- def JoblibProgress(
+ def JoblibProgress( # noqa: N802
  *columns: ProgressColumn | str,
  description: str | None = None,
  total: int | None = None,
@@ -68,7 +68,7 @@ def JoblibProgress(
  task_id = progress.add_task(description, total=total)
  print_progress = joblib.parallel.Parallel.print_progress

- def update_progress(self: joblib.parallel.Parallel):
+ def update_progress(self: joblib.parallel.Parallel) -> None:
  progress.update(task_id, completed=self.n_completed_tasks, refresh=True)
  return print_progress(self)


src/hydraflow/run_collection.py
@@ -23,8 +23,7 @@ from dataclasses import dataclass, field
  from itertools import chain
  from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload

- from mlflow.entities.run import Run
-
+ import hydraflow.param
  from hydraflow.config import iter_params
  from hydraflow.info import RunCollectionInfo

@@ -33,6 +32,7 @@ if TYPE_CHECKING:
  from pathlib import Path
  from typing import Any

+ from mlflow.entities.run import Run
  from omegaconf import DictConfig


@@ -60,7 +60,7 @@ class RunCollection:
  _info: RunCollectionInfo = field(init=False)
  """An instance of `RunCollectionInfo`."""

- def __post_init__(self):
+ def __post_init__(self) -> None:
  self._info = RunCollectionInfo(self)

  def __repr__(self) -> str:
@@ -87,9 +87,12 @@ class RunCollection:
  def __contains__(self, run: Run) -> bool:
  return run in self._runs

+ def __bool__(self) -> bool:
+ return bool(self._runs)
+
  @classmethod
  def from_list(cls, runs: list[Run]) -> RunCollection:
- """Create a new `RunCollection` instance from a list of MLflow `Run` instances."""
+ """Create a `RunCollection` instance from a list of MLflow `Run` instances."""

  return cls(runs)

@@ -120,6 +123,7 @@ class RunCollection:
  def sort(
  self,
  key: Callable[[Run], Any] | None = None,
+ *,
  reverse: bool = False,
  ) -> None:
  self._runs.sort(key=key or (lambda x: x.info.start_time), reverse=reverse)
@@ -393,7 +397,7 @@ class RunCollection:
  param_names = set()

  for run in self:
- for param in run.data.params.keys():
+ for param in run.data.params:
  param_names.add(param)

  return list(param_names)
@@ -537,10 +541,11 @@ class RunCollection:
  Results obtained by applying the function to each artifact directory
  in the collection.
  """
- return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir)
+ return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir) # noqa: A001

  def group_by(
- self, *names: str | list[str]
+ self,
+ *names: str | list[str],
  ) -> dict[tuple[str | None, ...], RunCollection]:
  """
  Group runs by specified parameter names.
@@ -568,37 +573,15 @@


  def _param_matches(run: Run, key: str, value: Any) -> bool:
- """
- Check if the run's parameter matches the specified key-value pair.
-
- Check if the run's parameters contain the specified
- key-value pair. It handles different types of values, including lists
- and tuples.
-
- Args:
- run (Run): The run object to check.
- key (str): The parameter key to check.
- value (Any): The parameter value to check.
-
- Returns:
- True if the run's parameter matches the specified key-value pair,
- False otherwise.
- """
- param = run.data.params.get(key, value)
-
- if param is None:
- return False
+ params = run.data.params
+ if key not in params:
+ return True

+ param = params[key]
  if param == "None":
- return value is None
-
- if isinstance(value, list) and value:
- return type(value[0])(param) in value
-
- if isinstance(value, tuple) and len(value) == 2:
- return value[0] <= type(value[0])(param) < value[1]
+ return value is None or value == "None"

- return type(value)(param) == value
+ return hydraflow.param.match(param, value)


  def filter_runs(
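
A small sketch of the new `RunCollection.__bool__` added above, combined with `list_runs` from the mlflow.py hunk earlier; illustrative only, and the experiment name is hypothetical.

```python
# Illustrative sketch, not part of the diff: a RunCollection is now truthy only
# when it contains runs, so emptiness checks read naturally.
from hydraflow.mlflow import list_runs

rc = list_runs("my-experiment")  # hypothetical experiment name
if rc:  # uses the new __bool__
    print(rc)  # __repr__ summarizes the collection
else:
    print("no runs found")
```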
tests/__init__.py: File without changes
tests/scripts/__init__.py: File without changes

tests/scripts/app.py
@@ -2,7 +2,7 @@ from __future__ import annotations

  import logging
  import time
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
  from pathlib import Path

  import hydra
@@ -18,6 +18,7 @@ log = logging.getLogger(__name__)
  class MySQLConfig:
  host: str = "localhost"
  port: int = 3306
+ values: list[int] = field(default_factory=lambda: [1, 2, 3])


  cs = ConfigStore.instance()