hydraflow 0.5.4__tar.gz → 0.6.1__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (56) hide show
  1. {hydraflow-0.5.4 → hydraflow-0.6.1}/PKG-INFO +1 -1
  2. {hydraflow-0.5.4 → hydraflow-0.6.1}/pyproject.toml +27 -3
  3. {hydraflow-0.5.4 → hydraflow-0.6.1}/src/hydraflow/__init__.py +1 -2
  4. {hydraflow-0.5.4 → hydraflow-0.6.1}/src/hydraflow/config.py +3 -0
  5. {hydraflow-0.5.4 → hydraflow-0.6.1}/src/hydraflow/context.py +53 -35
  6. {hydraflow-0.5.4 → hydraflow-0.6.1}/src/hydraflow/mlflow.py +10 -10
  7. {hydraflow-0.5.4 → hydraflow-0.6.1}/src/hydraflow/param.py +4 -0
  8. {hydraflow-0.5.4 → hydraflow-0.6.1}/src/hydraflow/run_collection.py +29 -0
  9. {hydraflow-0.5.4 → hydraflow-0.6.1}/src/hydraflow/run_data.py +1 -0
  10. {hydraflow-0.5.4 → hydraflow-0.6.1}/src/hydraflow/utils.py +23 -6
  11. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/config/overrides.py +1 -2
  12. hydraflow-0.6.1/tests/context/chdir.py +29 -0
  13. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/context/context.py +1 -5
  14. hydraflow-0.6.1/tests/context/logging.py +40 -0
  15. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/context/preemption.py +1 -2
  16. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/context/rerun.py +3 -7
  17. hydraflow-0.6.1/tests/context/test_chdir.py +27 -0
  18. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/context/test_context.py +1 -11
  19. hydraflow-0.6.1/tests/context/test_logging.py +51 -0
  20. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/param/params.py +1 -2
  21. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/param/test_params.py +1 -1
  22. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/run/filter.py +1 -2
  23. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/run/run.py +1 -2
  24. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/run/test_info.py +1 -1
  25. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/run/test_run.py +1 -1
  26. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/utils/test_utils.py +25 -0
  27. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/utils/utils.py +1 -2
  28. {hydraflow-0.5.4 → hydraflow-0.6.1}/.devcontainer/devcontainer.json +0 -0
  29. {hydraflow-0.5.4 → hydraflow-0.6.1}/.devcontainer/postCreate.sh +0 -0
  30. {hydraflow-0.5.4 → hydraflow-0.6.1}/.devcontainer/starship.toml +0 -0
  31. {hydraflow-0.5.4 → hydraflow-0.6.1}/.gitattributes +0 -0
  32. {hydraflow-0.5.4 → hydraflow-0.6.1}/.gitignore +0 -0
  33. {hydraflow-0.5.4 → hydraflow-0.6.1}/LICENSE +0 -0
  34. {hydraflow-0.5.4 → hydraflow-0.6.1}/README.md +0 -0
  35. {hydraflow-0.5.4 → hydraflow-0.6.1}/apps/quickstart.py +0 -0
  36. {hydraflow-0.5.4 → hydraflow-0.6.1}/mkdocs.yml +0 -0
  37. {hydraflow-0.5.4 → hydraflow-0.6.1}/src/hydraflow/py.typed +0 -0
  38. {hydraflow-0.5.4 → hydraflow-0.6.1}/src/hydraflow/run_info.py +0 -0
  39. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/__init__.py +0 -0
  40. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/config/__init__.py +0 -0
  41. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/config/test_config.py +0 -0
  42. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/config/test_overrides.py +0 -0
  43. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/config/test_params.py +0 -0
  44. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/conftest.py +0 -0
  45. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/context/__init__.py +0 -0
  46. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/context/test_preemption.py +0 -0
  47. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/context/test_rerun.py +0 -0
  48. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/param/__init__.py +0 -0
  49. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/param/test_param.py +0 -0
  50. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/run/__init__.py +0 -0
  51. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/run/test_collection.py +0 -0
  52. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/run/test_data.py +0 -0
  53. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/run/test_filter.py +0 -0
  54. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/test_mlflow.py +0 -0
  55. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/utils/__init__.py +0 -0
  56. {hydraflow-0.5.4 → hydraflow-0.6.1}/tests/utils/test_run.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hydraflow
3
- Version: 0.5.4
3
+ Version: 0.6.1
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://daizutabi.github.io/hydraflow/
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.5.4"
7
+ version = "0.6.1"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -68,8 +68,32 @@ target-version = "py310"
68
68
  [tool.ruff.lint]
69
69
  select = ["ALL"]
70
70
  unfixable = ["F401"]
71
- ignore = ["A005", "ANN003", "ANN401", "B904", "D", "EM101", "PGH003", "TRY003"]
71
+ ignore = [
72
+ "A005",
73
+ "ANN003",
74
+ "ANN401",
75
+ "B904",
76
+ "D105",
77
+ "D107",
78
+ "D203",
79
+ "D213",
80
+ "EM101",
81
+ "PGH003",
82
+ "PLR1704",
83
+ "TRY003",
84
+ ]
72
85
 
73
86
  [tool.ruff.lint.per-file-ignores]
74
- "tests/*" = ["A001", "ANN", "ARG", "FBT", "PLR", "PT", "S", "SIM108", "SLF"]
87
+ "tests/*" = [
88
+ "A001",
89
+ "ANN",
90
+ "ARG",
91
+ "D",
92
+ "FBT",
93
+ "PLR",
94
+ "PT",
95
+ "S",
96
+ "SIM108",
97
+ "SLF",
98
+ ]
75
99
  "apps/*.py" = ["D", "G", "INP"]
@@ -1,7 +1,7 @@
1
1
  """Integrate Hydra and MLflow to manage and track machine learning experiments."""
2
2
 
3
3
  from .config import select_config, select_overrides
4
- from .context import chdir_artifact, chdir_hydra_output, log_run, start_run
4
+ from .context import chdir_artifact, log_run, start_run
5
5
  from .mlflow import list_runs, search_runs, set_experiment
6
6
  from .run_collection import RunCollection
7
7
  from .utils import (
@@ -17,7 +17,6 @@ from .utils import (
17
17
  __all__ = [
18
18
  "RunCollection",
19
19
  "chdir_artifact",
20
- "chdir_hydra_output",
21
20
  "get_artifact_dir",
22
21
  "get_artifact_path",
23
22
  "get_hydra_output_dir",
@@ -22,6 +22,7 @@ def collect_params(config: object) -> dict[str, Any]:
22
22
 
23
23
  Returns:
24
24
  dict[str, Any]: A dictionary of collected parameters.
25
+
25
26
  """
26
27
  return dict(iter_params(config))
27
28
 
@@ -40,6 +41,7 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
40
41
 
41
42
  Yields:
42
43
  Key-value pairs representing the parameters in the configuration object.
44
+
43
45
  """
44
46
  if config is None:
45
47
  return
@@ -113,6 +115,7 @@ def select_config(config: object, names: list[str]) -> dict[str, Any]:
113
115
 
114
116
  Returns:
115
117
  DictConfig: A new configuration object containing only the selected parameters.
118
+
116
119
  """
117
120
  if not isinstance(config, DictConfig):
118
121
  config = OmegaConf.structured(config)
@@ -13,6 +13,7 @@ import mlflow.artifacts
13
13
  from hydra.core.hydra_config import HydraConfig
14
14
 
15
15
  from hydraflow.mlflow import log_params
16
+ from hydraflow.utils import get_artifact_dir
16
17
 
17
18
  if TYPE_CHECKING:
18
19
  from collections.abc import Iterator
@@ -48,6 +49,7 @@ def log_run(
48
49
  # Perform operations within the MLflow run context
49
50
  pass
50
51
  ```
52
+
51
53
  """
52
54
  if config:
53
55
  log_params(config, synchronous=synchronous)
@@ -55,7 +57,7 @@ def log_run(
55
57
  hc = HydraConfig.get()
56
58
  output_dir = Path(hc.runtime.output_dir)
57
59
 
58
- # Save '.hydra' config directory first.
60
+ # Save '.hydra' config directory.
59
61
  output_subdir = output_dir / (hc.output_subdir or "")
60
62
  mlflow.log_artifacts(output_subdir.as_posix(), hc.output_subdir)
61
63
 
@@ -68,14 +70,43 @@ def log_run(
68
70
  raise
69
71
 
70
72
  finally:
71
- # Save output_dir including '.hydra' config directory.
72
- mlflow.log_artifacts(output_dir.as_posix())
73
+ log_text(output_dir)
74
+
75
+
76
+ def log_text(directory: Path, pattern: str = "*.log") -> None:
77
+ """Log text files in the given directory as artifacts.
78
+
79
+ Append the text files to the existing text file in the artifact directory.
80
+
81
+ Args:
82
+ directory (Path): The directory to find the logs in.
83
+ pattern (str): The pattern to match the logs.
84
+
85
+ """
86
+ artifact_dir = get_artifact_dir()
87
+
88
+ for file in directory.glob(pattern):
89
+ if not file.is_file():
90
+ continue
91
+
92
+ file_artifact = artifact_dir / file.name
93
+ if file_artifact.exists():
94
+ text = file_artifact.read_text()
95
+ if not text.endswith("\n"):
96
+ text += "\n"
97
+ else:
98
+ text = ""
99
+
100
+ text += file.read_text()
101
+ mlflow.log_text(text, file.name)
73
102
 
74
103
 
75
104
  @contextmanager
76
105
  def start_run( # noqa: PLR0913
77
106
  config: object,
78
107
  *,
108
+ chdir: bool = False,
109
+ run: Run | None = None,
79
110
  run_id: str | None = None,
80
111
  experiment_id: str | None = None,
81
112
  run_name: str | None = None,
@@ -93,6 +124,9 @@ def start_run( # noqa: PLR0913
93
124
 
94
125
  Args:
95
126
  config (object): The configuration object to log parameters from.
127
+ chdir (bool): Whether to change the current working directory to the
128
+ artifact directory of the current run. Defaults to False.
129
+ run (Run | None): The existing run. Defaults to None.
96
130
  run_id (str | None): The existing run ID. Defaults to None.
97
131
  experiment_id (str | None): The experiment ID. Defaults to None.
98
132
  run_name (str | None): The name of the run. Defaults to None.
@@ -117,7 +151,11 @@ def start_run( # noqa: PLR0913
117
151
  - `mlflow.start_run`: The MLflow function to start a run directly.
118
152
  - `log_run`: A context manager to log parameters and manage the MLflow
119
153
  run context.
154
+
120
155
  """
156
+ if run:
157
+ run_id = run.info.run_id
158
+
121
159
  with (
122
160
  mlflow.start_run(
123
161
  run_id=run_id,
@@ -131,33 +169,15 @@ def start_run( # noqa: PLR0913
131
169
  ) as run,
132
170
  log_run(config if run_id is None else None, synchronous=synchronous),
133
171
  ):
134
- yield run
172
+ if chdir:
173
+ with chdir_artifact(run):
174
+ yield run
175
+ else:
176
+ yield run
135
177
 
136
178
 
137
179
  @contextmanager
138
- def chdir_hydra_output() -> Iterator[Path]:
139
- """Change the current working directory to the hydra output directory.
140
-
141
- This context manager changes the current working directory to the hydra output
142
- directory. It ensures that the directory is changed back to the original
143
- directory after the context is exited.
144
- """
145
- curdir = Path.cwd()
146
- path = HydraConfig.get().runtime.output_dir
147
-
148
- os.chdir(path)
149
- try:
150
- yield Path(path)
151
-
152
- finally:
153
- os.chdir(curdir)
154
-
155
-
156
- @contextmanager
157
- def chdir_artifact(
158
- run: Run,
159
- artifact_path: str | None = None,
160
- ) -> Iterator[Path]:
180
+ def chdir_artifact(run: Run | None = None) -> Iterator[Path]:
161
181
  """Change the current working directory to the artifact directory of the given run.
162
182
 
163
183
  This context manager changes the current working directory to the artifact
@@ -165,18 +185,16 @@ def chdir_artifact(
165
185
  to the original directory after the context is exited.
166
186
 
167
187
  Args:
168
- run (Run): The run to get the artifact directory from.
169
- artifact_path (str | None): The artifact path.
188
+ run (Run | None): The run to get the artifact directory from.
189
+
170
190
  """
171
191
  curdir = Path.cwd()
172
- path = mlflow.artifacts.download_artifacts(
173
- run_id=run.info.run_id,
174
- artifact_path=artifact_path,
175
- )
192
+ artifact_dir = get_artifact_dir(run)
193
+
194
+ os.chdir(artifact_dir)
176
195
 
177
- os.chdir(path)
178
196
  try:
179
- yield Path(path)
197
+ yield artifact_dir
180
198
 
181
199
  finally:
182
200
  os.chdir(curdir)
@@ -16,7 +16,6 @@ Key Features:
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- from pathlib import Path
20
19
  from typing import TYPE_CHECKING
21
20
 
22
21
  import joblib
@@ -28,8 +27,11 @@ from mlflow.tracking.fluent import SEARCH_MAX_RESULTS_PANDAS, _get_experiment_id
28
27
 
29
28
  from hydraflow.config import iter_params
30
29
  from hydraflow.run_collection import RunCollection
30
+ from hydraflow.utils import get_artifact_dir
31
31
 
32
32
  if TYPE_CHECKING:
33
+ from pathlib import Path
34
+
33
35
  from mlflow.entities.experiment import Experiment
34
36
 
35
37
 
@@ -54,6 +56,7 @@ def set_experiment(
54
56
  Returns:
55
57
  Experiment: An instance of `mlflow.entities.Experiment` representing
56
58
  the new active experiment.
59
+
57
60
  """
58
61
  if uri is not None:
59
62
  mlflow.set_tracking_uri(uri)
@@ -77,6 +80,7 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
77
80
  config (object): The configuration object to log the parameters from.
78
81
  synchronous (bool | None): Whether to log the parameters synchronously.
79
82
  Defaults to None.
83
+
80
84
  """
81
85
  for key, value in iter_params(config):
82
86
  mlflow.log_param(key, value, synchronous=synchronous)
@@ -133,6 +137,7 @@ def search_runs( # noqa: PLR0913
133
137
 
134
138
  Returns:
135
139
  A `RunCollection` object containing the search results.
140
+
136
141
  """
137
142
  runs = mlflow.search_runs(
138
143
  experiment_ids=experiment_ids,
@@ -177,6 +182,7 @@ def list_runs(
177
182
  Returns:
178
183
  RunCollection: A `RunCollection` instance containing the runs for the
179
184
  specified experiments.
185
+
180
186
  """
181
187
  rc = _list_runs(experiment_names, n_jobs)
182
188
  if status is None:
@@ -207,16 +213,10 @@ def _list_runs(
207
213
 
208
214
  for name in experiment_names:
209
215
  if experiment := mlflow.get_experiment_by_name(name):
210
- loc = experiment.artifact_location
211
-
212
- if isinstance(loc, str):
213
- if loc.startswith("file:"):
214
- path = Path(mlflow.artifacts.download_artifacts(loc))
215
- elif Path(loc).is_dir():
216
- path = Path(loc)
217
- else:
218
- continue # no cov
216
+ uri = experiment.artifact_location
219
217
 
218
+ if isinstance(uri, str):
219
+ path = get_artifact_dir(uri=uri)
220
220
  run_ids.extend(file.stem for file in path.iterdir() if file.is_dir())
221
221
 
222
222
  it = (joblib.delayed(mlflow.get_run)(run_id) for run_id in run_ids)
@@ -28,6 +28,7 @@ def match(param: str, value: Any) -> bool: # noqa: PLR0911
28
28
  Returns:
29
29
  True if the parameter matches the specified value,
30
30
  False otherwise.
31
+
31
32
  """
32
33
  if callable(value):
33
34
  return value(param)
@@ -94,6 +95,7 @@ def to_value(param: str | None, type_: type) -> Any:
94
95
 
95
96
  Returns:
96
97
  The converted value.
98
+
97
99
  """
98
100
  if param is None or param == "None":
99
101
  return None
@@ -129,6 +131,7 @@ def get_params(run: Run, *names: str | list[str]) -> tuple[str | None, ...]:
129
131
  Returns:
130
132
  tuple[str | None, ...]: A tuple containing the values of the specified
131
133
  parameters in the order they were provided.
134
+
132
135
  """
133
136
  names_ = []
134
137
  for name in names:
@@ -155,6 +158,7 @@ def get_values(run: Run, names: list[str], types: list[type]) -> tuple[Any, ...]
155
158
  Returns:
156
159
  tuple[Any, ...]: A tuple containing the values of the specified
157
160
  parameters in the order they were provided.
161
+
158
162
  """
159
163
  params = get_params(run, names)
160
164
  it = zip(params, types, strict=True)
@@ -106,6 +106,7 @@ class RunCollection:
106
106
 
107
107
  Returns:
108
108
  A new `RunCollection` instance with the runs from both collections.
109
+
109
110
  """
110
111
  return self.__class__(self._runs + other._runs)
111
112
 
@@ -118,6 +119,7 @@ class RunCollection:
118
119
  Returns:
119
120
  A new `RunCollection` instance with the runs that are in this collection
120
121
  but not in the other.
122
+
121
123
  """
122
124
  runs = [run for run in self._runs if run not in other._runs] # noqa: SLF001
123
125
  return self.__class__(runs)
@@ -150,6 +152,7 @@ class RunCollection:
150
152
  Returns:
151
153
  A new `RunCollection` instance containing the first n runs if n is
152
154
  positive, or the last n runs if n is negative.
155
+
153
156
  """
154
157
  if n < 0:
155
158
  return self.__class__(self._runs[n:])
@@ -164,6 +167,7 @@ class RunCollection:
164
167
 
165
168
  Raises:
166
169
  ValueError: If the collection does not contain exactly one run.
170
+
167
171
  """
168
172
  if len(self._runs) != 1:
169
173
  raise ValueError("The collection does not contain exactly one run.")
@@ -176,6 +180,7 @@ class RunCollection:
176
180
  Returns:
177
181
  The only `Run` instance in the collection, or None if the collection
178
182
  does not contain exactly one run.
183
+
179
184
  """
180
185
  return self._runs[0] if len(self._runs) == 1 else None
181
186
 
@@ -187,6 +192,7 @@ class RunCollection:
187
192
 
188
193
  Raises:
189
194
  ValueError: If the collection is empty.
195
+
190
196
  """
191
197
  if not self._runs:
192
198
  raise ValueError("The collection is empty.")
@@ -199,6 +205,7 @@ class RunCollection:
199
205
  Returns:
200
206
  The first `Run` instance in the collection, or None if the collection
201
207
  is empty.
208
+
202
209
  """
203
210
  return self._runs[0] if self._runs else None
204
211
 
@@ -210,6 +217,7 @@ class RunCollection:
210
217
 
211
218
  Raises:
212
219
  ValueError: If the collection is empty.
220
+
213
221
  """
214
222
  if not self._runs:
215
223
  raise ValueError("The collection is empty.")
@@ -222,6 +230,7 @@ class RunCollection:
222
230
  Returns:
223
231
  The last `Run` instance in the collection, or None if the collection
224
232
  is empty.
233
+
225
234
  """
226
235
  return self._runs[-1] if self._runs else None
227
236
 
@@ -262,6 +271,7 @@ class RunCollection:
262
271
 
263
272
  Returns:
264
273
  A new `RunCollection` object containing the filtered runs.
274
+
265
275
  """
266
276
  return RunCollection(
267
277
  filter_runs(
@@ -294,6 +304,7 @@ class RunCollection:
294
304
 
295
305
  See Also:
296
306
  `filter`: Perform the actual filtering logic.
307
+
297
308
  """
298
309
  try:
299
310
  return self.filter(config, **kwargs).first()
@@ -318,6 +329,7 @@ class RunCollection:
318
329
 
319
330
  See Also:
320
331
  `filter`: Perform the actual filtering logic.
332
+
321
333
  """
322
334
  return self.filter(config, **kwargs).try_first()
323
335
 
@@ -341,6 +353,7 @@ class RunCollection:
341
353
 
342
354
  See Also:
343
355
  `filter`: Perform the actual filtering logic.
356
+
344
357
  """
345
358
  try:
346
359
  return self.filter(config, **kwargs).last()
@@ -365,6 +378,7 @@ class RunCollection:
365
378
 
366
379
  See Also:
367
380
  `filter`: Perform the actual filtering logic.
381
+
368
382
  """
369
383
  return self.filter(config, **kwargs).try_last()
370
384
 
@@ -389,6 +403,7 @@ class RunCollection:
389
403
 
390
404
  See Also:
391
405
  `filter`: Perform the actual filtering logic.
406
+
392
407
  """
393
408
  try:
394
409
  return self.filter(config, **kwargs).one()
@@ -417,6 +432,7 @@ class RunCollection:
417
432
 
418
433
  See Also:
419
434
  `filter`: Perform the actual filtering logic.
435
+
420
436
  """
421
437
  return self.filter(config, **kwargs).try_one()
422
438
 
@@ -429,6 +445,7 @@ class RunCollection:
429
445
 
430
446
  Returns:
431
447
  A list of unique parameter names.
448
+
432
449
  """
433
450
  param_names = set()
434
451
 
@@ -453,6 +470,7 @@ class RunCollection:
453
470
  Returns:
454
471
  A dictionary where the keys are parameter names and the values are
455
472
  lists of parameter values.
473
+
456
474
  """
457
475
  params = {}
458
476
 
@@ -484,6 +502,7 @@ class RunCollection:
484
502
 
485
503
  Yields:
486
504
  Results obtained by applying the function to each run in the collection.
505
+
487
506
  """
488
507
  return (func(run, *args, **kwargs) for run in self)
489
508
 
@@ -504,6 +523,7 @@ class RunCollection:
504
523
  Yields:
505
524
  Results obtained by applying the function to each run id in the
506
525
  collection.
526
+
507
527
  """
508
528
  return (func(run_id, *args, **kwargs) for run_id in self.info.run_id)
509
529
 
@@ -524,6 +544,7 @@ class RunCollection:
524
544
  Yields:
525
545
  Results obtained by applying the function to each run configuration
526
546
  in the collection.
547
+
527
548
  """
528
549
  return (func(load_config(run), *args, **kwargs) for run in self)
529
550
 
@@ -548,6 +569,7 @@ class RunCollection:
548
569
  Yields:
549
570
  Results obtained by applying the function to each artifact URI in the
550
571
  collection.
572
+
551
573
  """
552
574
  return (func(uri, *args, **kwargs) for uri in self.info.artifact_uri)
553
575
 
@@ -571,6 +593,7 @@ class RunCollection:
571
593
  Yields:
572
594
  Results obtained by applying the function to each artifact directory
573
595
  in the collection.
596
+
574
597
  """
575
598
  return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir) # noqa: A001
576
599
 
@@ -594,6 +617,7 @@ class RunCollection:
594
617
  dictionary where the keys are tuples of parameter values and the
595
618
  values are `RunCollection` objects containing the runs that match
596
619
  those parameter values.
620
+
597
621
  """
598
622
  grouped_runs: dict[str | None | tuple[str | None, ...], list[Run]] = {}
599
623
  is_list = isinstance(names, list)
@@ -620,6 +644,7 @@ class RunCollection:
620
644
  key (Callable[[Run], Any] | None): A function that takes a run and returns
621
645
  a value to sort by.
622
646
  reverse (bool): If True, sort in descending order.
647
+
623
648
  """
624
649
  self._runs.sort(key=key or (lambda x: x.info.start_time), reverse=reverse)
625
650
 
@@ -633,6 +658,7 @@ class RunCollection:
633
658
 
634
659
  Returns:
635
660
  A list of values for the specified parameters.
661
+
636
662
  """
637
663
  is_list = isinstance(names, list)
638
664
 
@@ -664,6 +690,7 @@ class RunCollection:
664
690
  This can be a single parameter name or multiple names provided
665
691
  as separate arguments or as a list.
666
692
  reverse (bool): If True, sort in descending order.
693
+
667
694
  """
668
695
  values = self.values(names)
669
696
  index = sorted(range(len(self)), key=lambda i: values[i], reverse=reverse)
@@ -721,6 +748,7 @@ def filter_runs(
721
748
 
722
749
  Returns:
723
750
  A list of runs that match the specified configuration and key-value pairs.
751
+
724
752
  """
725
753
  if override:
726
754
  config = select_overrides(config)
@@ -751,6 +779,7 @@ def filter_runs_by_status(
751
779
 
752
780
  Returns:
753
781
  A list of runs that match the specified status.
782
+
754
783
  """
755
784
  if isinstance(status, str):
756
785
  if status.startswith("!"):
@@ -37,6 +37,7 @@ class RunCollectionData:
37
37
 
38
38
  Returns:
39
39
  A DataFrame containing the runs' configurations.
40
+
40
41
  """
41
42
  return DataFrame(self._runs.map_config(collect_params))
42
43
 
@@ -3,6 +3,8 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import shutil
6
+ import urllib.parse
7
+ import urllib.request
6
8
  from pathlib import Path
7
9
  from typing import TYPE_CHECKING
8
10
 
@@ -16,29 +18,40 @@ if TYPE_CHECKING:
16
18
  from collections.abc import Iterable
17
19
 
18
20
 
19
- def get_artifact_dir(run: Run | None = None) -> Path:
21
+ def get_artifact_dir(run: Run | None = None, uri: str | None = None) -> Path:
20
22
  """Retrieve the artifact directory for the given run.
21
23
 
22
24
  This function uses MLflow to get the artifact directory for the given run.
23
25
 
24
26
  Args:
25
27
  run (Run | None): The run object. Defaults to None.
28
+ uri (str | None): The URI of the artifact. Defaults to None.
26
29
 
27
30
  Returns:
28
31
  The local path to the directory where the artifacts are downloaded.
32
+
29
33
  """
30
- uri = mlflow.get_artifact_uri() if run is None else run.info.artifact_uri
34
+ if run is not None and uri is not None:
35
+ raise ValueError("Cannot provide both run and uri")
36
+
37
+ if run is None and uri is None:
38
+ uri = mlflow.get_artifact_uri()
39
+ elif run:
40
+ uri = run.info.artifact_uri
31
41
 
32
42
  if not isinstance(uri, str):
33
43
  raise NotImplementedError
34
44
 
35
45
  if uri.startswith("file:"):
36
- return Path(mlflow.artifacts.download_artifacts(uri))
46
+ return file_uri_to_path(uri)
37
47
 
38
- if Path(uri).is_dir():
39
- return Path(uri)
48
+ return Path(uri)
40
49
 
41
- raise NotImplementedError
50
+
51
+ def file_uri_to_path(uri: str) -> Path:
52
+ """Convert a file URI to a local path."""
53
+ path = urllib.parse.urlparse(uri).path
54
+ return Path(urllib.request.url2pathname(path)) # for Windows
42
55
 
43
56
 
44
57
  def get_artifact_path(run: Run | None, path: str) -> Path:
@@ -52,6 +65,7 @@ def get_artifact_path(run: Run | None, path: str) -> Path:
52
65
 
53
66
  Returns:
54
67
  The local path to the artifact.
68
+
55
69
  """
56
70
  return get_artifact_dir(run) / path
57
71
 
@@ -74,6 +88,7 @@ def get_hydra_output_dir(run: Run | None = None) -> Path:
74
88
  Raises:
75
89
  FileNotFoundError: If the Hydra configuration file is not found
76
90
  in the artifacts.
91
+
77
92
  """
78
93
  if run is None:
79
94
  hc = HydraConfig.get()
@@ -102,6 +117,7 @@ def load_config(run: Run) -> DictConfig:
102
117
  Returns:
103
118
  The loaded configuration as a DictConfig object. Returns an empty
104
119
  DictConfig if the configuration file is not found.
120
+
105
121
  """
106
122
  path = get_artifact_dir(run) / ".hydra/config.yaml"
107
123
  return OmegaConf.load(path) # type: ignore
@@ -126,6 +142,7 @@ def load_overrides(run: Run) -> list[str]:
126
142
  Returns:
127
143
  The loaded overrides as a list of strings. Returns an empty list
128
144
  if the overrides file is not found.
145
+
129
146
  """
130
147
  path = get_artifact_dir(run) / ".hydra/overrides.yaml"
131
148
  return [str(x) for x in OmegaConf.load(path)]
@@ -16,8 +16,7 @@ class Config:
16
16
  height: float = 1.7
17
17
 
18
18
 
19
- cs = ConfigStore.instance()
20
- cs.store(name="config", node=Config)
19
+ ConfigStore.instance().store(name="config", node=Config)
21
20
 
22
21
 
23
22
  @hydra.main(version_base=None, config_name="config")
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+
6
+ import hydra
7
+ from hydra.core.config_store import ConfigStore
8
+
9
+ import hydraflow
10
+
11
+
12
+ @dataclass
13
+ class Config:
14
+ count: int = 0
15
+
16
+
17
+ ConfigStore.instance().store(name="config", node=Config)
18
+
19
+
20
+ @hydra.main(version_base=None, config_name="config")
21
+ def app(cfg: Config):
22
+ hydraflow.set_experiment()
23
+
24
+ with hydraflow.start_run(cfg, chdir=True):
25
+ Path("a.txt").write_text(str(cfg.count))
26
+
27
+
28
+ if __name__ == "__main__":
29
+ app()
@@ -15,8 +15,7 @@ class Config:
15
15
  name: str = "a"
16
16
 
17
17
 
18
- cs = ConfigStore.instance()
19
- cs.store(name="config", node=Config)
18
+ ConfigStore.instance().store(name="config", node=Config)
20
19
 
21
20
 
22
21
  @hydra.main(version_base=None, config_name="config")
@@ -24,9 +23,6 @@ def app(cfg: Config):
24
23
  hydraflow.set_experiment()
25
24
 
26
25
  with hydraflow.start_run(cfg) as run:
27
- with hydraflow.chdir_hydra_output():
28
- Path("a.txt").write_text("chdir_hydra_output")
29
-
30
26
  with hydraflow.chdir_artifact(run):
31
27
  Path("b.txt").write_text("chdir_artifact")
32
28
 
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from pathlib import Path
6
+
7
+ import hydra
8
+ from hydra.core.config_store import ConfigStore
9
+ from hydra.core.hydra_config import HydraConfig
10
+
11
+ import hydraflow
12
+
13
+ log = logging.getLogger(__name__)
14
+
15
+
16
+ @dataclass
17
+ class Config:
18
+ count: int = 0
19
+
20
+
21
+ ConfigStore.instance().store(name="config", node=Config)
22
+
23
+
24
+ @hydra.main(version_base=None, config_name="config")
25
+ def app(cfg: Config):
26
+ hydraflow.set_experiment()
27
+
28
+ run = hydraflow.list_runs().try_get(cfg, override=True)
29
+
30
+ with hydraflow.start_run(cfg, run=run):
31
+ log.info("second" if run else "first")
32
+ log.info(cfg.count)
33
+
34
+ output_dir = Path(HydraConfig.get().runtime.output_dir)
35
+ (output_dir / "text.log").write_text("text\n")
36
+ (output_dir / "dir.log").mkdir()
37
+
38
+
39
+ if __name__ == "__main__":
40
+ app()
@@ -17,8 +17,7 @@ class Config:
17
17
  count: int = 0
18
18
 
19
19
 
20
- cs = ConfigStore.instance()
21
- cs.store(name="config", node=Config)
20
+ ConfigStore.instance().store(name="config", node=Config)
22
21
 
23
22
 
24
23
  @hydra.main(version_base=None, config_name="config")
@@ -17,20 +17,16 @@ class Config:
17
17
  count: int = 0
18
18
 
19
19
 
20
- cs = ConfigStore.instance()
21
- cs.store(name="config", node=Config)
20
+ ConfigStore.instance().store(name="config", node=Config)
22
21
 
23
22
 
24
23
  @hydra.main(version_base=None, config_name="config")
25
24
  def app(cfg: Config):
26
25
  hydraflow.set_experiment()
27
26
 
28
- if run := hydraflow.list_runs().try_find(cfg, override=True):
29
- run_id = run.info.run_id
30
- else:
31
- run_id = None
27
+ run = hydraflow.list_runs().try_find(cfg, override=True)
32
28
 
33
- with hydraflow.start_run(cfg, run_id=run_id) as run:
29
+ with hydraflow.start_run(cfg, run=run) as run:
34
30
  log(hydraflow.get_artifact_dir(run))
35
31
 
36
32
 
@@ -0,0 +1,27 @@
1
+ import pytest
2
+ from mlflow.entities import Run
3
+
4
+ from hydraflow.run_collection import RunCollection
5
+
6
+ pytestmark = pytest.mark.xdist_group(name="group4")
7
+
8
+
9
+ @pytest.fixture(scope="module")
10
+ def rc(collect):
11
+ return collect("context/chdir.py", ["-m", "count=1,2"])
12
+
13
+
14
+ def test_rc_len(rc: RunCollection):
15
+ assert len(rc) == 2
16
+
17
+
18
+ @pytest.fixture(scope="module", params=[1, 2])
19
+ def run(rc: RunCollection, request: pytest.FixtureRequest):
20
+ return rc.get(count=request.param)
21
+
22
+
23
+ def test_run_count(run: Run):
24
+ from hydraflow.utils import get_artifact_path
25
+
26
+ text = get_artifact_path(run, "a.txt").read_text()
27
+ assert text == run.data.params["count"]
@@ -2,7 +2,7 @@ import pytest
2
2
  from mlflow.entities import Run
3
3
 
4
4
  from hydraflow.run_collection import RunCollection
5
- from hydraflow.utils import get_artifact_path, get_hydra_output_dir
5
+ from hydraflow.utils import get_artifact_path
6
6
 
7
7
  pytestmark = pytest.mark.xdist_group(name="group2")
8
8
 
@@ -18,16 +18,6 @@ def run(rc: RunCollection, request: pytest.FixtureRequest):
18
18
  return rc[request.param]
19
19
 
20
20
 
21
- def test_chdir_hydra_output(run: Run):
22
- path = get_hydra_output_dir(run)
23
- assert (path / "a.txt").read_text() == "chdir_hydra_output"
24
-
25
-
26
21
  def test_chdir_artifact(run: Run):
27
22
  path = get_artifact_path(run, "b.txt")
28
23
  assert path.read_text() == "chdir_artifact"
29
-
30
-
31
- def test_log_run(run: Run):
32
- path = get_artifact_path(run, "a.txt")
33
- assert path.read_text() == "chdir_hydra_output"
@@ -0,0 +1,51 @@
1
+ import pytest
2
+ from mlflow.entities import Run
3
+
4
+ from hydraflow.run_collection import RunCollection
5
+ from hydraflow.utils import get_artifact_path
6
+
7
+ pytestmark = pytest.mark.xdist_group(name="group6")
8
+
9
+
10
+ @pytest.fixture(scope="module")
11
+ def rc(collect):
12
+ collect("context/logging.py", ["count=100"])
13
+ return collect("context/logging.py", ["count=100"])
14
+
15
+
16
+ def test_rc_len(rc: RunCollection):
17
+ assert len(rc) == 1
18
+
19
+
20
+ @pytest.fixture(scope="module")
21
+ def run(rc: RunCollection):
22
+ return rc[0]
23
+
24
+
25
+ @pytest.fixture(scope="module")
26
+ def hydra_log(run: Run, experiment_name: str):
27
+ path = get_artifact_path(run, f"{experiment_name}.log")
28
+ return path.read_text()
29
+
30
+
31
+ @pytest.mark.parametrize(
32
+ ("i", "suffix"),
33
+ [(0, "] - first"), (1, "] - 100"), (2, "] - second"), (3, "] - 100")],
34
+ )
35
+ def test_hydra_log(hydra_log: str, i: int, suffix: str):
36
+ assert hydra_log.splitlines()[i].endswith(suffix)
37
+
38
+
39
+ def test_text_log(run: Run):
40
+ path = get_artifact_path(run, "text.log")
41
+ assert path.read_text() == "text\ntext\n"
42
+
43
+
44
+ def test_dir_log(run: Run):
45
+ assert not get_artifact_path(run, "dir.log").exists()
46
+
47
+
48
+ def test_config(run: Run):
49
+ path = get_artifact_path(run, ".hydra/config.yaml")
50
+ cfg = path.read_text()
51
+ assert cfg == "count: 100\n"
@@ -21,8 +21,7 @@ class Config:
21
21
  data: Data = field(default_factory=Data)
22
22
 
23
23
 
24
- cs = ConfigStore.instance()
25
- cs.store(name="config", node=Config)
24
+ ConfigStore.instance().store(name="config", node=Config)
26
25
 
27
26
 
28
27
  @hydra.main(version_base=None, config_name="config")
@@ -3,7 +3,7 @@ from mlflow.entities import Run
3
3
 
4
4
  from hydraflow.run_collection import RunCollection
5
5
 
6
- pytestmark = pytest.mark.xdist_group(name="group2")
6
+ pytestmark = pytest.mark.xdist_group(name="group1")
7
7
 
8
8
 
9
9
  @pytest.fixture(scope="module")
@@ -14,8 +14,7 @@ class Config:
14
14
  port: int = 3306
15
15
 
16
16
 
17
- cs = ConfigStore.instance()
18
- cs.store(name="config", node=Config)
17
+ ConfigStore.instance().store(name="config", node=Config)
19
18
 
20
19
 
21
20
  @hydra.main(version_base=None, config_name="config")
@@ -21,8 +21,7 @@ class Config:
21
21
  data: Data = field(default_factory=Data)
22
22
 
23
23
 
24
- cs = ConfigStore.instance()
25
- cs.store(name="config", node=Config)
24
+ ConfigStore.instance().store(name="config", node=Config)
26
25
 
27
26
 
28
27
  @hydra.main(version_base=None, config_name="config")
@@ -4,7 +4,7 @@ from mlflow.entities import Experiment
4
4
 
5
5
  from hydraflow.run_collection import RunCollection
6
6
 
7
- pytestmark = pytest.mark.xdist_group(name="group4")
7
+ pytestmark = pytest.mark.xdist_group(name="group7")
8
8
 
9
9
 
10
10
  @pytest.fixture(scope="module")
@@ -7,7 +7,7 @@ from hydraflow.run_collection import RunCollection
7
7
  if TYPE_CHECKING:
8
8
  from .run import Config
9
9
 
10
- pytestmark = pytest.mark.xdist_group(name="group4")
10
+ pytestmark = pytest.mark.xdist_group(name="group5")
11
11
 
12
12
 
13
13
  @pytest.fixture(scope="module")
@@ -1,3 +1,4 @@
1
+ import sys
1
2
  from typing import TYPE_CHECKING
2
3
 
3
4
  import pytest
@@ -27,6 +28,30 @@ def run(rc: RunCollection):
27
28
  return rc.first()
28
29
 
29
30
 
31
+ @pytest.mark.parametrize(
32
+ ("uri", "path"),
33
+ [("/a/b/c", "/a/b/c"), ("file:///a/b/c", "/a/b/c"), ("file:C:/a/b/c", "C:/a/b/c")],
34
+ )
35
+ def test_file_uri_to_path(uri, path):
36
+ from hydraflow.utils import file_uri_to_path
37
+
38
+ assert file_uri_to_path(uri).as_posix() == path
39
+
40
+
41
+ @pytest.mark.skipif(sys.platform != "win32", reason="This test is for Windows")
42
+ def test_file_uri_to_path_win10_11():
43
+ from hydraflow.utils import file_uri_to_path
44
+
45
+ assert file_uri_to_path("file:///C:/a/b/c").as_posix() == "C:/a/b/c"
46
+
47
+
48
+ def test_artifact_dir_error(run: Run):
49
+ from hydraflow.utils import get_artifact_dir
50
+
51
+ with pytest.raises(ValueError):
52
+ get_artifact_dir(run, "a")
53
+
54
+
30
55
  def test_hydra_output_dir(run: Run):
31
56
  from hydraflow.utils import get_artifact_path, get_hydra_output_dir
32
57
 
@@ -16,8 +16,7 @@ class Config:
16
16
  height: float = 1.7
17
17
 
18
18
 
19
- cs = ConfigStore.instance()
20
- cs.store(name="config", node=Config)
19
+ ConfigStore.instance().store(name="config", node=Config)
21
20
 
22
21
 
23
22
  @hydra.main(version_base=None, config_name="config")
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes