hydraflow-0.6.0-py3-none-any.whl → hydraflow-0.6.2-py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
hydraflow/__init__.py CHANGED
@@ -1,7 +1,7 @@
1
1
  """Integrate Hydra and MLflow to manage and track machine learning experiments."""
2
2
 
3
3
  from .config import select_config, select_overrides
4
- from .context import chdir_artifact, chdir_hydra_output, log_run, start_run
4
+ from .context import chdir_artifact, log_run, start_run
5
5
  from .mlflow import list_runs, search_runs, set_experiment
6
6
  from .run_collection import RunCollection
7
7
  from .utils import (
@@ -17,7 +17,6 @@ from .utils import (
17
17
  __all__ = [
18
18
  "RunCollection",
19
19
  "chdir_artifact",
20
- "chdir_hydra_output",
21
20
  "get_artifact_dir",
22
21
  "get_artifact_path",
23
22
  "get_hydra_output_dir",
hydraflow/context.py CHANGED
@@ -13,6 +13,7 @@ import mlflow.artifacts
13
13
  from hydra.core.hydra_config import HydraConfig
14
14
 
15
15
  from hydraflow.mlflow import log_params
16
+ from hydraflow.utils import get_artifact_dir
16
17
 
17
18
  if TYPE_CHECKING:
18
19
  from collections.abc import Iterator
@@ -69,24 +70,26 @@ def log_run(
69
70
  raise
70
71
 
71
72
  finally:
72
- log_hydra(output_dir)
73
+ log_text(output_dir)
73
74
 
74
75
 
75
- def log_hydra(output_dir: Path) -> None:
76
- """Log hydra logs of the current run as artifacts.
76
+ def log_text(directory: Path, pattern: str = "*.log") -> None:
77
+ """Log text files in the given directory as artifacts.
78
+
79
+ Append the text files to the existing text file in the artifact directory.
77
80
 
78
81
  Args:
79
- output_dir (Path): The output directory of the Hydra job.
82
+ directory (Path): The directory to find the logs in.
83
+ pattern (str): The pattern to match the logs.
80
84
 
81
85
  """
82
- uri = mlflow.get_artifact_uri()
83
- artifact_dir = Path(mlflow.artifacts.download_artifacts(uri))
86
+ artifact_dir = get_artifact_dir()
84
87
 
85
- for file_hydra in output_dir.glob("*.log"):
86
- if not file_hydra.is_file():
88
+ for file in directory.glob(pattern):
89
+ if not file.is_file():
87
90
  continue
88
91
 
89
- file_artifact = artifact_dir / file_hydra.name
92
+ file_artifact = artifact_dir / file.name
90
93
  if file_artifact.exists():
91
94
  text = file_artifact.read_text()
92
95
  if not text.endswith("\n"):
@@ -94,8 +97,8 @@ def log_hydra(output_dir: Path) -> None:
94
97
  else:
95
98
  text = ""
96
99
 
97
- text += file_hydra.read_text()
98
- mlflow.log_text(text, file_hydra.name)
100
+ text += file.read_text()
101
+ mlflow.log_text(text, file.name)
99
102
 
100
103
 
101
104
  @contextmanager
@@ -174,29 +177,7 @@ def start_run( # noqa: PLR0913
174
177
 
175
178
 
176
179
  @contextmanager
177
- def chdir_hydra_output() -> Iterator[Path]:
178
- """Change the current working directory to the hydra output directory.
179
-
180
- This context manager changes the current working directory to the hydra output
181
- directory. It ensures that the directory is changed back to the original
182
- directory after the context is exited.
183
- """
184
- curdir = Path.cwd()
185
- path = HydraConfig.get().runtime.output_dir
186
-
187
- os.chdir(path)
188
- try:
189
- yield Path(path)
190
-
191
- finally:
192
- os.chdir(curdir)
193
-
194
-
195
- @contextmanager
196
- def chdir_artifact(
197
- run: Run,
198
- artifact_path: str | None = None,
199
- ) -> Iterator[Path]:
180
+ def chdir_artifact(run: Run | None = None) -> Iterator[Path]:
200
181
  """Change the current working directory to the artifact directory of the given run.
201
182
 
202
183
  This context manager changes the current working directory to the artifact
@@ -204,19 +185,16 @@ def chdir_artifact(
204
185
  to the original directory after the context is exited.
205
186
 
206
187
  Args:
207
- run (Run): The run to get the artifact directory from.
208
- artifact_path (str | None): The artifact path.
188
+ run (Run | None): The run to get the artifact directory from.
209
189
 
210
190
  """
211
191
  curdir = Path.cwd()
212
- path = mlflow.artifacts.download_artifacts(
213
- run_id=run.info.run_id,
214
- artifact_path=artifact_path,
215
- )
192
+ artifact_dir = get_artifact_dir(run)
193
+
194
+ os.chdir(artifact_dir)
216
195
 
217
- os.chdir(path)
218
196
  try:
219
- yield Path(path)
197
+ yield artifact_dir
220
198
 
221
199
  finally:
222
200
  os.chdir(curdir)
hydraflow/mlflow.py CHANGED
@@ -16,7 +16,6 @@ Key Features:
16
16
 
17
17
  from __future__ import annotations
18
18
 
19
- from pathlib import Path
20
19
  from typing import TYPE_CHECKING
21
20
 
22
21
  import joblib
@@ -28,8 +27,11 @@ from mlflow.tracking.fluent import SEARCH_MAX_RESULTS_PANDAS, _get_experiment_id
28
27
 
29
28
  from hydraflow.config import iter_params
30
29
  from hydraflow.run_collection import RunCollection
30
+ from hydraflow.utils import get_artifact_dir
31
31
 
32
32
  if TYPE_CHECKING:
33
+ from pathlib import Path
34
+
33
35
  from mlflow.entities.experiment import Experiment
34
36
 
35
37
 
@@ -211,16 +213,10 @@ def _list_runs(
211
213
 
212
214
  for name in experiment_names:
213
215
  if experiment := mlflow.get_experiment_by_name(name):
214
- loc = experiment.artifact_location
215
-
216
- if isinstance(loc, str):
217
- if loc.startswith("file:"):
218
- path = Path(mlflow.artifacts.download_artifacts(loc))
219
- elif Path(loc).is_dir():
220
- path = Path(loc)
221
- else:
222
- continue # no cov
216
+ uri = experiment.artifact_location
223
217
 
218
+ if isinstance(uri, str):
219
+ path = get_artifact_dir(uri=uri)
224
220
  run_ids.extend(file.stem for file in path.iterdir() if file.is_dir())
225
221
 
226
222
  it = (joblib.delayed(mlflow.get_run)(run_id) for run_id in run_ids)
hydraflow/run_collection.py CHANGED
@@ -236,7 +236,7 @@ class RunCollection:
236
236
 
237
237
  def filter(
238
238
  self,
239
- config: object | None = None,
239
+ config: object | Callable[[Run], bool] | None = None,
240
240
  *,
241
241
  override: bool = False,
242
242
  select: list[str] | None = None,
@@ -257,11 +257,13 @@ class RunCollection:
257
257
  - Membership checks for lists of values.
258
258
  - Range checks for tuples of two values (inclusive of both the lower
259
259
  and upper bound).
260
+ - Callable that takes a `Run` object and returns a boolean value.
260
261
 
261
262
  Args:
262
- config (object | None): The configuration object to filter the runs.
263
- This can be any object that provides key-value pairs through
264
- the `iter_params` function.
263
+ config (object | Callable[[Run], bool] | None): The configuration object
264
+ to filter the runs. This can be any object that provides key-value
265
+ pairs through the `iter_params` function, or a callable that
266
+ takes a `Run` object and returns a boolean value.
265
267
  override (bool): If True, override the configuration object with the
266
268
  provided key-value pairs.
267
269
  select (list[str] | None): The list of parameters to select.
@@ -711,7 +713,7 @@ def _param_matches(run: Run, key: str, value: Any) -> bool:
711
713
 
712
714
  def filter_runs(
713
715
  runs: list[Run],
714
- config: object | None = None,
716
+ config: object | Callable[[Run], bool] | None = None,
715
717
  *,
716
718
  override: bool = False,
717
719
  select: list[str] | None = None,
@@ -735,9 +737,11 @@ def filter_runs(
735
737
 
736
738
  Args:
737
739
  runs (list[Run]): The list of runs to filter.
738
- config (object | None, optional): The configuration object to filter the
739
- runs. This can be any object that provides key-value pairs through
740
- the `iter_params` function. Defaults to None.
740
+ config (object | Callable[[Run], bool] | None, optional): The
741
+ configuration object to filter the runs. This can be any object
742
+ that provides key-value pairs through the `iter_params` function.
743
+ This can also be a callable that takes a `Run` object and returns
744
+ a boolean value. Defaults to None.
741
745
  override (bool, optional): If True, filter the runs based on
742
746
  the overrides. Defaults to False.
743
747
  select (list[str] | None, optional): The list of parameters to select.
@@ -750,15 +754,19 @@ def filter_runs(
750
754
  A list of runs that match the specified configuration and key-value pairs.
751
755
 
752
756
  """
753
- if override:
754
- config = select_overrides(config)
755
- elif select:
756
- config = select_config(config, select)
757
-
758
- for key, value in chain(iter_params(config), kwargs.items()):
759
- runs = [run for run in runs if _param_matches(run, key, value)]
760
- if not runs:
761
- return []
757
+ if callable(config):
758
+ runs = [run for run in runs if config(run)]
759
+
760
+ else:
761
+ if override:
762
+ config = select_overrides(config)
763
+ elif select:
764
+ config = select_config(config, select)
765
+
766
+ for key, value in chain(iter_params(config), kwargs.items()):
767
+ runs = [run for run in runs if _param_matches(run, key, value)]
768
+ if not runs:
769
+ return []
762
770
 
763
771
  if status is None:
764
772
  return runs
hydraflow/utils.py CHANGED
@@ -3,6 +3,8 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  import shutil
6
+ import urllib.parse
7
+ import urllib.request
6
8
  from pathlib import Path
7
9
  from typing import TYPE_CHECKING
8
10
 
@@ -16,30 +18,40 @@ if TYPE_CHECKING:
16
18
  from collections.abc import Iterable
17
19
 
18
20
 
19
- def get_artifact_dir(run: Run | None = None) -> Path:
21
+ def get_artifact_dir(run: Run | None = None, uri: str | None = None) -> Path:
20
22
  """Retrieve the artifact directory for the given run.
21
23
 
22
24
  This function uses MLflow to get the artifact directory for the given run.
23
25
 
24
26
  Args:
25
27
  run (Run | None): The run object. Defaults to None.
28
+ uri (str | None): The URI of the artifact. Defaults to None.
26
29
 
27
30
  Returns:
28
31
  The local path to the directory where the artifacts are downloaded.
29
32
 
30
33
  """
31
- uri = mlflow.get_artifact_uri() if run is None else run.info.artifact_uri
34
+ if run is not None and uri is not None:
35
+ raise ValueError("Cannot provide both run and uri")
36
+
37
+ if run is None and uri is None:
38
+ uri = mlflow.get_artifact_uri()
39
+ elif run:
40
+ uri = run.info.artifact_uri
32
41
 
33
42
  if not isinstance(uri, str):
34
43
  raise NotImplementedError
35
44
 
36
45
  if uri.startswith("file:"):
37
- return Path(mlflow.artifacts.download_artifacts(uri))
46
+ return file_uri_to_path(uri)
47
+
48
+ return Path(uri)
38
49
 
39
- if Path(uri).is_dir():
40
- return Path(uri)
41
50
 
42
- raise NotImplementedError
51
+ def file_uri_to_path(uri: str) -> Path:
52
+ """Convert a file URI to a local path."""
53
+ path = urllib.parse.urlparse(uri).path
54
+ return Path(urllib.request.url2pathname(path)) # for Windows
43
55
 
44
56
 
45
57
  def get_artifact_path(run: Run | None, path: str) -> Path:
hydraflow-0.6.0.dist-info/METADATA → hydraflow-0.6.2.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hydraflow
3
- Version: 0.6.0
3
+ Version: 0.6.2
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://daizutabi.github.io/hydraflow/
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
hydraflow-0.6.2.dist-info/RECORD ADDED
@@ -0,0 +1,14 @@
1
+ hydraflow/__init__.py,sha256=VPIPNNCyjMAkWBbdvB7Ltwe3QWoc2FwuqkV8uJM5JoM,809
2
+ hydraflow/config.py,sha256=MNX9da5bPVDcjnpji7Cm9ndK6ura92pt361m4PRh6_E,4326
3
+ hydraflow/context.py,sha256=3xfKhMozkKFqtWeOp9Gie0A5o5URMta4US6iVD5TcLU,6002
4
+ hydraflow/mlflow.py,sha256=imD3XL0RTlpnKrkyvO8FNy_Bv6hwSfLiOu1yJuL40ck,8773
5
+ hydraflow/param.py,sha256=yu1aMNXRLegXGDL-68vwIkfeDF9CaU784WZENGLwl7Q,4572
6
+ hydraflow/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
+ hydraflow/run_collection.py,sha256=w_GZdc_6yviwzRWLndWDSWB4DKyGyA9di9d9UpkkLZo,27926
8
+ hydraflow/run_data.py,sha256=dpyyfnuH9mCtIZeigMo1iFQo9bafMdEL4i4uI2l0UqY,1525
9
+ hydraflow/run_info.py,sha256=Jf5wrIjRLIV1-k-obHDqwKHa6j_ZonrY8od-rXlbtMo,1024
10
+ hydraflow/utils.py,sha256=a9i5PEJn8Ssowv9dqHadAihZXlsqtVjHZ9MZvkPq1bY,4747
11
+ hydraflow-0.6.2.dist-info/METADATA,sha256=9a3blsQ91rNP1Ql4kFDc7tZxDMbdK5PzEAfP9ZyUY6A,4700
12
+ hydraflow-0.6.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
+ hydraflow-0.6.2.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
14
+ hydraflow-0.6.2.dist-info/RECORD,,
hydraflow-0.6.0.dist-info/RECORD REMOVED
@@ -1,14 +0,0 @@
1
- hydraflow/__init__.py,sha256=9XO9FD3uiTTPN6X6UAC9FtkJjEqUQZNqpoAmSrjUHfI,855
2
- hydraflow/config.py,sha256=MNX9da5bPVDcjnpji7Cm9ndK6ura92pt361m4PRh6_E,4326
3
- hydraflow/context.py,sha256=rc43zvE2ueki0zEzorCMIthD9cho_PkbLLJYF9WgDqY,6562
4
- hydraflow/mlflow.py,sha256=h2S_A2wElr_1lAq0D1wkoEfdtDZpPuWFNRcO8mV_VrA,8932
5
- hydraflow/param.py,sha256=yu1aMNXRLegXGDL-68vwIkfeDF9CaU784WZENGLwl7Q,4572
6
- hydraflow/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- hydraflow/run_collection.py,sha256=2GRVOy87_2SPjHuCzzUvRNugO_grtFUVjtTfhznwBAc,27444
8
- hydraflow/run_data.py,sha256=dpyyfnuH9mCtIZeigMo1iFQo9bafMdEL4i4uI2l0UqY,1525
9
- hydraflow/run_info.py,sha256=Jf5wrIjRLIV1-k-obHDqwKHa6j_ZonrY8od-rXlbtMo,1024
10
- hydraflow/utils.py,sha256=oXjcyfQBbPzJNTh3_CbZfl23zgJS-mbNM9GAWBwsn8c,4349
11
- hydraflow-0.6.0.dist-info/METADATA,sha256=xUib1EsbG3Es5jFx0cSkF1QItfTuciBHYM1040GqFzA,4700
12
- hydraflow-0.6.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
- hydraflow-0.6.0.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
14
- hydraflow-0.6.0.dist-info/RECORD,,