hydraflow 0.2.16__tar.gz → 0.2.18__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (37) hide show
  1. {hydraflow-0.2.16 → hydraflow-0.2.18}/PKG-INFO +1 -1
  2. {hydraflow-0.2.16 → hydraflow-0.2.18}/pyproject.toml +19 -4
  3. {hydraflow-0.2.16 → hydraflow-0.2.18}/src/hydraflow/asyncio.py +13 -11
  4. {hydraflow-0.2.16 → hydraflow-0.2.18}/src/hydraflow/config.py +3 -6
  5. {hydraflow-0.2.16 → hydraflow-0.2.18}/src/hydraflow/context.py +15 -15
  6. {hydraflow-0.2.16 → hydraflow-0.2.18}/src/hydraflow/info.py +16 -6
  7. {hydraflow-0.2.16 → hydraflow-0.2.18}/src/hydraflow/mlflow.py +36 -23
  8. hydraflow-0.2.18/src/hydraflow/param.py +75 -0
  9. {hydraflow-0.2.16 → hydraflow-0.2.18}/src/hydraflow/progress.py +7 -18
  10. {hydraflow-0.2.16 → hydraflow-0.2.18}/src/hydraflow/run_collection.py +122 -99
  11. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/scripts/app.py +2 -1
  12. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/test_app.py +36 -8
  13. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/test_config.py +22 -1
  14. hydraflow-0.2.18/tests/test_param.py +78 -0
  15. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/test_run_collection.py +149 -118
  16. {hydraflow-0.2.16 → hydraflow-0.2.18}/.devcontainer/devcontainer.json +0 -0
  17. {hydraflow-0.2.16 → hydraflow-0.2.18}/.devcontainer/postCreate.sh +0 -0
  18. {hydraflow-0.2.16 → hydraflow-0.2.18}/.devcontainer/starship.toml +0 -0
  19. {hydraflow-0.2.16 → hydraflow-0.2.18}/.gitattributes +0 -0
  20. {hydraflow-0.2.16 → hydraflow-0.2.18}/.gitignore +0 -0
  21. {hydraflow-0.2.16 → hydraflow-0.2.18}/LICENSE +0 -0
  22. {hydraflow-0.2.16 → hydraflow-0.2.18}/README.md +0 -0
  23. {hydraflow-0.2.16 → hydraflow-0.2.18}/mkdocs.yml +0 -0
  24. {hydraflow-0.2.16 → hydraflow-0.2.18}/src/hydraflow/__init__.py +0 -0
  25. {hydraflow-0.2.16 → hydraflow-0.2.18}/src/hydraflow/py.typed +0 -0
  26. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/__init__.py +0 -0
  27. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/scripts/__init__.py +0 -0
  28. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/scripts/progress.py +0 -0
  29. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/scripts/watch.py +0 -0
  30. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/test_asyncio.py +0 -0
  31. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/test_context.py +0 -0
  32. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/test_info.py +0 -0
  33. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/test_log_run.py +0 -0
  34. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/test_mlflow.py +0 -0
  35. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/test_progress.py +0 -0
  36. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/test_version.py +0 -0
  37. {hydraflow-0.2.16 → hydraflow-0.2.18}/tests/test_watch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.2.16
3
+ Version: 0.2.18
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.2.16"
7
+ version = "0.2.18"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -66,17 +66,32 @@ target-version = "py310"
66
66
 
67
67
  [tool.ruff.lint]
68
68
  select = ["ALL"]
69
+ unfixable = ["F401"]
69
70
  ignore = [
70
71
  "ANN003",
71
72
  "ANN401",
72
73
  "ARG002",
73
74
  "B904",
74
- "D",
75
+ "D105",
76
+ "D107",
77
+ "D203",
78
+ "D213",
75
79
  "EM101",
76
80
  "PGH003",
77
81
  "TRY003",
78
82
  ]
79
- exclude = ["tests/scripts/*.py"]
83
+ exclude = ["tests/scripts/*.py", "src/hydraflow/__init__.py"]
80
84
 
81
85
  [tool.ruff.lint.per-file-ignores]
82
- "tests/*" = ["A001", "ANN", "ARG", "FBT", "PLR", "PT", "S", "SIM117", "SLF"]
86
+ "tests/*" = [
87
+ "A001",
88
+ "ANN",
89
+ "ARG",
90
+ "D",
91
+ "FBT",
92
+ "PLR",
93
+ "PT",
94
+ "S",
95
+ "SIM117",
96
+ "SLF",
97
+ ]
@@ -1,3 +1,5 @@
1
+ """Provide functionality for running commands and monitoring file changes."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  import asyncio
@@ -27,8 +29,7 @@ async def execute_command(
27
29
  stderr: Callable[[str], None] | None = None,
28
30
  stop_event: asyncio.Event,
29
31
  ) -> int:
30
- """
31
- Runs a command asynchronously and pass the output to callback functions.
32
+ """Run a command asynchronously and pass the output to callback functions.
32
33
 
33
34
  Args:
34
35
  program (str): The program to run.
@@ -39,6 +40,7 @@ async def execute_command(
39
40
 
40
41
  Returns:
41
42
  int: The return code of the process.
43
+
42
44
  """
43
45
  try:
44
46
  process = await asyncio.create_subprocess_exec(
@@ -68,13 +70,13 @@ async def process_stream(
68
70
  stream: StreamReader | None,
69
71
  callback: Callable[[str], None] | None,
70
72
  ) -> None:
71
- """
72
- Reads a stream asynchronously and pass each line to a callback function.
73
+ """Read a stream asynchronously and pass each line to a callback function.
73
74
 
74
75
  Args:
75
76
  stream (StreamReader | None): The stream to read from.
76
77
  callback (Callable[[str], None] | None): The callback function to handle
77
78
  each line.
79
+
78
80
  """
79
81
  if stream is None or callback is None:
80
82
  return
@@ -93,9 +95,7 @@ async def monitor_file_changes(
93
95
  stop_event: asyncio.Event,
94
96
  **awatch_kwargs,
95
97
  ) -> None:
96
- """
97
- Watches for file changes in specified paths and pass the changes to a
98
- callback function.
98
+ """Watch file changes in specified paths and pass the changes to a callback.
99
99
 
100
100
  Args:
101
101
  paths (list[str | Path]): List of paths to monitor for changes.
@@ -103,6 +103,7 @@ async def monitor_file_changes(
103
103
  function to handle file changes.
104
104
  stop_event (asyncio.Event): Event to signal when to stop watching.
105
105
  **awatch_kwargs: Additional keyword arguments to pass to watchfiles.awatch.
106
+
106
107
  """
107
108
  str_paths = [str(path) for path in paths]
108
109
  try:
@@ -127,8 +128,7 @@ async def run_and_monitor(
127
128
  paths: list[str | Path] | None = None,
128
129
  **awatch_kwargs,
129
130
  ) -> int:
130
- """
131
- Runs a command and optionally watch for file changes concurrently.
131
+ """Run a command and optionally watch for file changes concurrently.
132
132
 
133
133
  Args:
134
134
  program (str): The program to run.
@@ -138,6 +138,8 @@ async def run_and_monitor(
138
138
  watch (Callable[[set[tuple[Change, str]]], None] | None): Callback for
139
139
  file changes.
140
140
  paths (list[str | Path] | None): List of paths to monitor for changes.
141
+ **awatch_kwargs: Additional keyword arguments to pass to `watchfiles.awatch`.
142
+
141
143
  """
142
144
  stop_event = asyncio.Event()
143
145
  run_task = asyncio.create_task(
@@ -184,8 +186,7 @@ def run(
184
186
  paths: list[str | Path] | None = None,
185
187
  **awatch_kwargs,
186
188
  ) -> int:
187
- """
188
- Run a command synchronously and optionally watch for file changes.
189
+ """Run a command synchronously and optionally watch for file changes.
189
190
 
190
191
  This function is a synchronous wrapper around the asynchronous
191
192
  `run_and_monitor` function. It runs a specified command and optionally
@@ -208,6 +209,7 @@ def run(
208
209
 
209
210
  Returns:
210
211
  int: The return code of the process.
212
+
211
213
  """
212
214
  if watch and not paths:
213
215
  paths = [Path.cwd()]
@@ -1,7 +1,4 @@
1
- """
2
- This module provides functionality for working with configuration
3
- objects using the OmegaConf library.
4
- """
1
+ """Provide functionality for working with configuration objects using the OmegaConf."""
5
2
 
6
3
  from __future__ import annotations
7
4
 
@@ -15,8 +12,7 @@ if TYPE_CHECKING:
15
12
 
16
13
 
17
14
  def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
18
- """
19
- Recursively iterate over the parameters in the given configuration object.
15
+ """Recursively iterate over the parameters in the given configuration object.
20
16
 
21
17
  This function traverses the configuration object and yields key-value pairs
22
18
  representing the parameters. The keys are prefixed with the provided prefix.
@@ -29,6 +25,7 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
29
25
 
30
26
  Yields:
31
27
  Key-value pairs representing the parameters in the configuration object.
28
+
32
29
  """
33
30
  if config is None:
34
31
  return
@@ -1,7 +1,4 @@
1
- """
2
- This module provides context managers to log parameters and manage the MLflow
3
- run context.
4
- """
1
+ """Provide context managers to log parameters and manage the MLflow run context."""
5
2
 
6
3
  from __future__ import annotations
7
4
 
@@ -34,9 +31,7 @@ def log_run(
34
31
  *,
35
32
  synchronous: bool | None = None,
36
33
  ) -> Iterator[None]:
37
- """
38
- Log the parameters from the given configuration object and manage the MLflow
39
- run context.
34
+ """Log the parameters from the given configuration object.
40
35
 
41
36
  This context manager logs the parameters from the provided configuration object
42
37
  using MLflow. It also manages the MLflow run context, ensuring that artifacts
@@ -56,6 +51,7 @@ def log_run(
56
51
  # Perform operations within the MLflow run context
57
52
  pass
58
53
  ```
54
+
59
55
  """
60
56
  log_params(config, synchronous=synchronous)
61
57
 
@@ -98,8 +94,7 @@ def start_run( # noqa: PLR0913
98
94
  log_system_metrics: bool | None = None,
99
95
  synchronous: bool | None = None,
100
96
  ) -> Iterator[Run]:
101
- """
102
- Start an MLflow run and log parameters using the provided configuration object.
97
+ """Start an MLflow run and log parameters using the provided configuration object.
103
98
 
104
99
  This context manager starts an MLflow run and logs parameters using the specified
105
100
  configuration object. It ensures that the run is properly closed after completion.
@@ -130,6 +125,7 @@ def start_run( # noqa: PLR0913
130
125
  - `mlflow.start_run`: The MLflow function to start a run directly.
131
126
  - `log_run`: A context manager to log parameters and manage the MLflow
132
127
  run context.
128
+
133
129
  """
134
130
  with (
135
131
  mlflow.start_run(
@@ -156,9 +152,7 @@ def watch(
156
152
  ignore_patterns: list[str] | None = None,
157
153
  ignore_log: bool = True,
158
154
  ) -> Iterator[None]:
159
- """
160
- Watch the given directory for changes and call the provided function
161
- when a change is detected.
155
+ """Watch the given directory for changes.
162
156
 
163
157
  This context manager sets up a file system watcher on the specified directory.
164
158
  When a file modification is detected, the provided function is called with
@@ -173,6 +167,9 @@ def watch(
173
167
  the current MLflow artifact URI is used. Defaults to "".
174
168
  timeout (int): The timeout period in seconds for the watcher
175
169
  to run after the context is exited. Defaults to 60.
170
+ ignore_patterns (list[str] | None): A list of glob patterns to ignore.
171
+ Defaults to None.
172
+ ignore_log (bool): Whether to ignore log files. Defaults to True.
176
173
 
177
174
  Yields:
178
175
  None
@@ -183,6 +180,7 @@ def watch(
183
180
  # Perform operations while watching the directory for changes
184
181
  pass
185
182
  ```
183
+
186
184
  """
187
185
  dir = dir or get_artifact_dir() # noqa: A001
188
186
  if isinstance(dir, Path):
@@ -214,6 +212,8 @@ def watch(
214
212
 
215
213
 
216
214
  class Handler(PatternMatchingEventHandler):
215
+ """Monitor file changes and call the given function when a change is detected."""
216
+
217
217
  def __init__(
218
218
  self,
219
219
  func: Callable[[Path], None],
@@ -232,6 +232,7 @@ class Handler(PatternMatchingEventHandler):
232
232
  super().__init__(ignore_patterns=ignore_patterns)
233
233
 
234
234
  def on_modified(self, event: FileModifiedEvent) -> None:
235
+ """Modify when a file is modified."""
235
236
  file = Path(str(event.src_path))
236
237
  if file.is_file():
237
238
  self.func(file)
@@ -242,9 +243,7 @@ def chdir_artifact(
242
243
  run: Run,
243
244
  artifact_path: str | None = None,
244
245
  ) -> Iterator[Path]:
245
- """
246
- Change the current working directory to the artifact directory of the
247
- given run.
246
+ """Change the current working directory to the artifact directory of the given run.
248
247
 
249
248
  This context manager changes the current working directory to the artifact
250
249
  directory of the given run. It ensures that the directory is changed back
@@ -253,6 +252,7 @@ def chdir_artifact(
253
252
  Args:
254
253
  run (Run): The run to get the artifact directory from.
255
254
  artifact_path (str | None): The artifact path.
255
+
256
256
  """
257
257
  curdir = Path.cwd()
258
258
  path = mlflow.artifacts.download_artifacts(
@@ -1,3 +1,5 @@
1
+ """Provide information about MLflow runs."""
2
+
1
3
  from __future__ import annotations
2
4
 
3
5
  from pathlib import Path
@@ -15,37 +17,44 @@ if TYPE_CHECKING:
15
17
 
16
18
 
17
19
class RunCollectionInfo:
    """Expose attributes of the runs in a collection as parallel lists.

    Every property walks the wrapped collection once, in iteration order, so
    the i-th entry of each returned list belongs to the i-th run.
    """

    def __init__(self, runs: RunCollection) -> None:
        self._runs = runs

    @property
    def run_id(self) -> list[str]:
        """Return ``run.info.run_id`` for every run in the collection."""
        return [r.info.run_id for r in self._runs]

    @property
    def params(self) -> list[dict[str, str]]:
        """Return the ``run.data.params`` mapping for every run in the collection."""
        return [r.data.params for r in self._runs]

    @property
    def metrics(self) -> list[dict[str, float]]:
        """Return the ``run.data.metrics`` mapping for every run in the collection."""
        return [r.data.metrics for r in self._runs]

    @property
    def artifact_uri(self) -> list[str | None]:
        """Return ``run.info.artifact_uri`` for every run in the collection."""
        return [r.info.artifact_uri for r in self._runs]

    @property
    def artifact_dir(self) -> list[Path]:
        """Return the artifact directory of every run, via ``get_artifact_dir``."""
        return list(map(get_artifact_dir, self._runs))

    @property
    def config(self) -> list[DictConfig]:
        """Return the configuration of every run, via ``load_config``."""
        return list(map(load_config, self._runs))
44
54
 
45
55
 
46
56
  def get_artifact_dir(run: Run | None = None) -> Path:
47
- """
48
- Retrieve the artifact directory for the given run.
57
+ """Retrieve the artifact directory for the given run.
49
58
 
50
59
  This function uses MLflow to get the artifact directory for the given run.
51
60
 
@@ -54,6 +63,7 @@ def get_artifact_dir(run: Run | None = None) -> Path:
54
63
 
55
64
  Returns:
56
65
  The local path to the directory where the artifacts are downloaded.
66
+
57
67
  """
58
68
  if run is None:
59
69
  uri = mlflow.get_artifact_uri()
@@ -64,8 +74,7 @@ def get_artifact_dir(run: Run | None = None) -> Path:
64
74
 
65
75
 
66
76
  def get_hydra_output_dir(run: Run | None = None) -> Path:
67
- """
68
- Retrieve the Hydra output directory for the given run.
77
+ """Retrieve the Hydra output directory for the given run.
69
78
 
70
79
  This function returns the Hydra output directory. If no run is provided,
71
80
  it retrieves the output directory from the current Hydra configuration.
@@ -82,6 +91,7 @@ def get_hydra_output_dir(run: Run | None = None) -> Path:
82
91
  Raises:
83
92
  FileNotFoundError: If the Hydra configuration file is not found
84
93
  in the artifacts.
94
+
85
95
  """
86
96
  if run is None:
87
97
  hc = HydraConfig.get()
@@ -97,8 +107,7 @@ def get_hydra_output_dir(run: Run | None = None) -> Path:
97
107
 
98
108
 
99
109
  def load_config(run: Run) -> DictConfig:
100
- """
101
- Load the configuration for a given run.
110
+ """Load the configuration for a given run.
102
111
 
103
112
  This function loads the configuration for the provided Run instance
104
113
  by downloading the configuration file from the MLflow artifacts and
@@ -111,6 +120,7 @@ def load_config(run: Run) -> DictConfig:
111
120
  Returns:
112
121
  The loaded configuration as a DictConfig object. Returns an empty
113
122
  DictConfig if the configuration file is not found.
123
+
114
124
  """
115
125
  path = get_artifact_dir(run) / ".hydra/config.yaml"
116
126
  return OmegaConf.load(path) # type: ignore
@@ -1,20 +1,17 @@
1
- """
2
- This module provides functionality to log parameters from Hydra configuration objects
3
- and set up experiments using MLflow. It includes methods for managing experiments,
4
- searching for runs, and logging parameters and artifacts.
1
+ """Provide functionality to log parameters from Hydra configuration objects.
2
+
3
+ This module provides functions to log parameters from Hydra configuration objects
4
+ to MLflow, set experiments, and manage tracking URIs. It integrates Hydra's
5
+ configuration management with MLflow's experiment tracking capabilities.
5
6
 
6
7
  Key Features:
7
- - **Experiment Management**: Set and manage MLflow experiments with customizable names
8
- based on Hydra configuration.
9
- - **Run Logging**: Log parameters and metrics from Hydra configuration objects to
10
- MLflow, ensuring that all relevant information is captured during experiments.
11
- - **Run Search**: Search for runs based on various criteria, allowing for flexible
12
- retrieval of experiment results.
13
- - **Artifact Management**: Retrieve and log artifacts associated with runs, facilitating
14
- easy access to outputs generated during experiments.
15
-
16
- This module is designed to integrate seamlessly with Hydra, providing a robust
17
- solution for tracking machine learning experiments and their associated metadata.
8
+ - **Experiment Management**: Set experiment names and tracking URIs using Hydra
9
+ configuration details.
10
+ - **Parameter Logging**: Log parameters from Hydra configuration objects to MLflow,
11
+ supporting both synchronous and asynchronous logging.
12
+ - **Run Collection**: Utilize the `RunCollection` class to manage and interact with
13
+ multiple MLflow runs, providing methods to filter and retrieve runs based on
14
+ various criteria.
18
15
  """
19
16
 
20
17
  from __future__ import annotations
@@ -40,8 +37,7 @@ def set_experiment(
40
37
  suffix: str = "",
41
38
  uri: str | Path | None = None,
42
39
  ) -> Experiment:
43
- """
44
- Sets the experiment name and tracking URI optionally.
40
+ """Set the experiment name and tracking URI optionally.
45
41
 
46
42
  This function sets the experiment name by combining the given prefix,
47
43
  the job name from HydraConfig, and the given suffix. Optionally, it can
@@ -55,6 +51,7 @@ def set_experiment(
55
51
  Returns:
56
52
  Experiment: An instance of `mlflow.entities.Experiment` representing
57
53
  the new active experiment.
54
+
58
55
  """
59
56
  if uri is not None:
60
57
  mlflow.set_tracking_uri(uri)
@@ -65,8 +62,7 @@ def set_experiment(
65
62
 
66
63
 
67
64
  def log_params(config: object, *, synchronous: bool | None = None) -> None:
68
- """
69
- Log the parameters from the given configuration object.
65
+ """Log the parameters from the given configuration object.
70
66
 
71
67
  This method logs the parameters from the provided configuration object
72
68
  using MLflow. It iterates over the parameters and logs them using the
@@ -76,6 +72,7 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
76
72
  config (object): The configuration object to log the parameters from.
77
73
  synchronous (bool | None): Whether to log the parameters synchronously.
78
74
  Defaults to None.
75
+
79
76
  """
80
77
  for key, value in iter_params(config):
81
78
  mlflow.log_param(key, value, synchronous=synchronous)
@@ -91,8 +88,7 @@ def search_runs( # noqa: PLR0913
91
88
  search_all_experiments: bool = False,
92
89
  experiment_names: list[str] | None = None,
93
90
  ) -> RunCollection:
94
- """
95
- Search for Runs that fit the specified criteria.
91
+ """Search for Runs that fit the specified criteria.
96
92
 
97
93
  This function wraps the `mlflow.search_runs` function and returns the
98
94
  results as a `RunCollection` object. It allows for flexible searching of
@@ -133,6 +129,7 @@ def search_runs( # noqa: PLR0913
133
129
 
134
130
  Returns:
135
131
  A `RunCollection` object containing the search results.
132
+
136
133
  """
137
134
  runs = mlflow.search_runs(
138
135
  experiment_ids=experiment_ids,
@@ -151,9 +148,9 @@ def search_runs( # noqa: PLR0913
151
148
  def list_runs(
152
149
  experiment_names: str | list[str] | None = None,
153
150
  n_jobs: int = 0,
151
+ status: str | list[str] | int | list[int] | None = None,
154
152
  ) -> RunCollection:
155
- """
156
- List all runs for the specified experiments.
153
+ """List all runs for the specified experiments.
157
154
 
158
155
  This function retrieves all runs for the given list of experiment names.
159
156
  If no experiment names are provided (None), it defaults to searching all runs
@@ -169,11 +166,27 @@ def list_runs(
169
166
  for runs. If None or an empty list is provided, the function will
170
167
  search the currently active experiment or all experiments except
171
168
  the "Default" experiment.
169
+ n_jobs (int): The number of jobs to run in parallel. If 0, the function
170
+ will search runs sequentially.
171
+ status (str | list[str] | int | list[int] | None): The status of the runs
172
+ to filter.
172
173
 
173
174
  Returns:
174
175
  RunCollection: A `RunCollection` instance containing the runs for the
175
176
  specified experiments.
177
+
176
178
  """
179
+ rc = _list_runs(experiment_names, n_jobs)
180
+ if status is None:
181
+ return rc
182
+
183
+ return rc.filter(status=status)
184
+
185
+
186
+ def _list_runs(
187
+ experiment_names: str | list[str] | None = None,
188
+ n_jobs: int = 0,
189
+ ) -> RunCollection:
177
190
  if isinstance(experiment_names, str):
178
191
  experiment_names = [experiment_names]
179
192
 
@@ -0,0 +1,75 @@
1
+ """Provide utility functions for parameter matching.
2
+
3
+ The main function `match` checks if a given parameter matches a specified value.
4
+ It supports various types of values including None, boolean, list, tuple, int,
5
+ float, and str.
6
+
7
+ Helper functions `_match_list` and `_match_tuple` are used internally to handle
8
+ matching for list and tuple types respectively.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import Any
14
+
15
+
16
+ def match(param: str, value: Any) -> bool:
17
+ """Check if the string matches the specified value.
18
+
19
+ Args:
20
+ param (str): The parameter to check.
21
+ value (Any): The value to check.
22
+
23
+ Returns:
24
+ True if the parameter matches the specified value,
25
+ False otherwise.
26
+
27
+ """
28
+ if value in [None, True, False]:
29
+ return param == str(value)
30
+
31
+ if isinstance(value, list) and (m := _match_list(param, value)) is not None:
32
+ return m
33
+
34
+ if isinstance(value, tuple) and (m := _match_tuple(param, value)) is not None:
35
+ return m
36
+
37
+ if isinstance(value, int | float | str):
38
+ return type(value)(param) == value
39
+
40
+ return param == str(value)
41
+
42
+
43
+ def _match_list(param: str, value: list) -> bool | None:
44
+ if not value:
45
+ return None
46
+
47
+ if any(param.startswith(x) for x in ["[", "(", "{"]):
48
+ return None
49
+
50
+ if isinstance(value[0], bool):
51
+ return None
52
+
53
+ if not isinstance(value[0], int | float | str):
54
+ return None
55
+
56
+ return type(value[0])(param) in value
57
+
58
+
59
+ def _match_tuple(param: str, value: tuple) -> bool | None:
60
+ if len(value) != 2: # noqa: PLR2004
61
+ return None
62
+
63
+ if any(param.startswith(x) for x in ["[", "(", "{"]):
64
+ return None
65
+
66
+ if isinstance(value[0], bool):
67
+ return None
68
+
69
+ if not isinstance(value[0], int | float | str):
70
+ return None
71
+
72
+ if type(value[0]) is not type(value[1]):
73
+ return None
74
+
75
+ return value[0] <= type(value[0])(param) < value[1] # type: ignore
@@ -1,18 +1,7 @@
1
- """
2
- Module for managing progress tracking in parallel processing using Joblib
3
- and Rich's Progress bar.
1
+ """Context managers and functions for parallel task execution with progress.
4
2
 
5
3
  Provide context managers and functions to facilitate the execution
6
4
  of tasks in parallel while displaying progress updates.
7
-
8
- The following key components are provided:
9
-
10
- - JoblibProgress: A context manager for tracking progress with Rich's progress
11
- bar.
12
- - parallel_progress: A function to execute a given function in parallel over
13
- an iterable with progress tracking.
14
- - multi_tasks_progress: A function to render auto-updating progress bars for
15
- multiple tasks concurrently.
16
5
  """
17
6
 
18
7
  from __future__ import annotations
@@ -37,8 +26,7 @@ def JoblibProgress( # noqa: N802
37
26
  total: int | None = None,
38
27
  **kwargs,
39
28
  ) -> Iterator[Progress]:
40
- """
41
- Context manager for tracking progress using Joblib with Rich's Progress bar.
29
+ """Context manager for tracking progress using Joblib with Rich's Progress bar.
42
30
 
43
31
  Args:
44
32
  *columns (ProgressColumn | str): Columns to display in the progress bar.
@@ -56,6 +44,7 @@ def JoblibProgress( # noqa: N802
56
44
  with JoblibProgress("task", total=100) as progress:
57
45
  # Your parallel processing code here
58
46
  ```
47
+
59
48
  """
60
49
  if not columns:
61
50
  columns = Progress.get_default_columns()
@@ -94,8 +83,7 @@ def parallel_progress(
94
83
  description: str | None = None,
95
84
  **kwargs,
96
85
  ) -> list[U]:
97
- """
98
- Execute a function in parallel over an iterable with progress tracking.
86
+ """Execute a function in parallel over an iterable with progress tracking.
99
87
 
100
88
  Args:
101
89
  func (Callable[[T], U]): The function to execute on each item in the
@@ -112,6 +100,7 @@ def parallel_progress(
112
100
  Returns:
113
101
  list[U]: A list of results from applying the function to each item in
114
102
  the iterable.
103
+
115
104
  """
116
105
  iterable = list(iterable)
117
106
  total = len(iterable)
@@ -130,8 +119,7 @@ def multi_tasks_progress(
130
119
  transient: bool | None = None,
131
120
  **kwargs,
132
121
  ) -> None:
133
- """
134
- Render auto-updating progress bars for multiple tasks concurrently.
122
+ """Render auto-updating progress bars for multiple tasks concurrently.
135
123
 
136
124
  Args:
137
125
  iterables (Iterable[Iterable[int | tuple[int, int]]]): A collection of
@@ -151,6 +139,7 @@ def multi_tasks_progress(
151
139
 
152
140
  Returns:
153
141
  None
142
+
154
143
  """
155
144
  if not columns:
156
145
  columns = Progress.get_default_columns()