hydraflow 0.7.5__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hydraflow/__init__.py CHANGED
@@ -1,25 +1,16 @@
1
1
  """Integrate Hydra and MLflow to manage and track machine learning experiments."""
2
2
 
3
- from hydraflow.config import select_config, select_overrides
4
- from hydraflow.context import chdir_artifact, log_run, start_run
5
- from hydraflow.main import main
6
- from hydraflow.mlflow import (
7
- list_run_ids,
8
- list_run_paths,
9
- list_runs,
10
- search_runs,
11
- set_experiment,
12
- )
13
- from hydraflow.run_collection import RunCollection
14
- from hydraflow.utils import (
3
+ from hydraflow.core.context import chdir_artifact, log_run, start_run
4
+ from hydraflow.core.io import (
15
5
  get_artifact_dir,
16
6
  get_artifact_path,
17
7
  get_hydra_output_dir,
18
- get_overrides,
19
8
  load_config,
20
- load_overrides,
21
9
  remove_run,
22
10
  )
11
+ from hydraflow.core.main import main
12
+ from hydraflow.core.mlflow import list_run_ids, list_run_paths, list_runs
13
+ from hydraflow.entities.run_collection import RunCollection
23
14
 
24
15
  __all__ = [
25
16
  "RunCollection",
@@ -27,18 +18,12 @@ __all__ = [
27
18
  "get_artifact_dir",
28
19
  "get_artifact_path",
29
20
  "get_hydra_output_dir",
30
- "get_overrides",
31
21
  "list_run_ids",
32
22
  "list_run_paths",
33
23
  "list_runs",
34
24
  "load_config",
35
- "load_overrides",
36
25
  "log_run",
37
26
  "main",
38
27
  "remove_run",
39
- "search_runs",
40
- "select_config",
41
- "select_overrides",
42
- "set_experiment",
43
28
  "start_run",
44
29
  ]
hydraflow/cli.py CHANGED
@@ -2,41 +2,54 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
- from pathlib import Path
6
- from typing import Annotated
5
+ from typing import TYPE_CHECKING, Annotated
7
6
 
8
7
  import typer
9
- from omegaconf import DictConfig, OmegaConf
10
8
  from rich.console import Console
11
9
  from typer import Argument, Option
12
10
 
11
+ from hydraflow.executor.io import load_config
12
+
13
+ if TYPE_CHECKING:
14
+ from hydraflow.executor.job import Job
15
+
13
16
  app = typer.Typer(add_completion=False)
14
17
  console = Console()
15
18
 
16
19
 
20
+ def get_job(name: str) -> Job:
21
+ cfg = load_config()
22
+ job = cfg.jobs[name]
23
+
24
+ if not job.name:
25
+ job.name = name
26
+
27
+ return job
28
+
29
+
17
30
  @app.command()
18
31
  def run(
19
- names: Annotated[
20
- list[str] | None,
21
- Argument(help="Job names.", show_default=False),
22
- ] = None,
32
+ name: Annotated[str, Argument(help="Job name.", show_default=False)],
23
33
  ) -> None:
24
- """Run jobs."""
25
- typer.echo(names)
34
+ """Run a job."""
35
+ import mlflow
26
36
 
27
- cfg = load_config()
28
- typer.echo(cfg)
37
+ from hydraflow.executor.job import multirun
38
+
39
+ job = get_job(name)
40
+ mlflow.set_experiment(job.name)
41
+ multirun(job)
29
42
 
30
43
 
31
44
  @app.command()
32
- def show() -> None:
33
- """Show the config."""
34
- from rich.syntax import Syntax
45
+ def show(
46
+ name: Annotated[str, Argument(help="Job name.", show_default=False)],
47
+ ) -> None:
48
+ """Show a job."""
49
+ from hydraflow.executor.job import show
35
50
 
36
- cfg = load_config()
37
- code = OmegaConf.to_yaml(cfg)
38
- syntax = Syntax(code, "yaml")
39
- console.print(syntax)
51
+ job = get_job(name)
52
+ show(job)
40
53
 
41
54
 
42
55
  @app.callback(invoke_without_command=True)
@@ -52,24 +65,3 @@ def callback(
52
65
 
53
66
  typer.echo(f"hydraflow {importlib.metadata.version('hydraflow')}")
54
67
  raise typer.Exit
55
-
56
-
57
- def find_config() -> Path:
58
- if Path("hydraflow.yaml").exists():
59
- return Path("hydraflow.yaml")
60
-
61
- if Path("hydraflow.yml").exists():
62
- return Path("hydraflow.yml")
63
-
64
- typer.echo("No config file found.")
65
- raise typer.Exit(code=1)
66
-
67
-
68
- def load_config() -> DictConfig:
69
- cfg = OmegaConf.load(find_config())
70
-
71
- if isinstance(cfg, DictConfig):
72
- return cfg
73
-
74
- typer.echo("Invalid config file.")
75
- raise typer.Exit(code=1)
File without changes
@@ -6,35 +6,19 @@ from typing import TYPE_CHECKING
6
6
 
7
7
  from omegaconf import DictConfig, ListConfig, OmegaConf
8
8
 
9
- from hydraflow.utils import get_overrides
10
-
11
9
  if TYPE_CHECKING:
12
10
  from collections.abc import Iterator
13
11
  from typing import Any
14
12
 
15
13
 
16
- def collect_params(config: object) -> dict[str, Any]:
17
- """Iterate over parameters and collect them into a dictionary.
18
-
19
- Args:
20
- config (object): The configuration object to iterate over.
21
- prefix (str): The prefix to prepend to the parameter keys.
22
-
23
- Returns:
24
- dict[str, Any]: A dictionary of collected parameters.
25
-
26
- """
27
- return dict(iter_params(config))
28
-
29
-
30
- def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
14
+ def iter_params(config: Any, prefix: str = "") -> Iterator[tuple[str, Any]]:
31
15
  """Recursively iterate over the parameters in the given configuration object.
32
16
 
33
17
  This function traverses the configuration object and yields key-value pairs
34
18
  representing the parameters. The keys are prefixed with the provided prefix.
35
19
 
36
20
  Args:
37
- config (object): The configuration object to iterate over. This can be a
21
+ config (Any): The configuration object to iterate over. This can be a
38
22
  dictionary, list, DictConfig, or ListConfig.
39
23
  prefix (str): The prefix to prepend to the parameter keys.
40
24
  Defaults to an empty string.
@@ -50,7 +34,7 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
50
34
  config = _from_dotlist(config)
51
35
 
52
36
  if not isinstance(config, DictConfig | ListConfig):
53
- config = OmegaConf.create(config) # type: ignore
37
+ config = OmegaConf.create(config)
54
38
 
55
39
  yield from _iter_params(config, prefix)
56
40
 
@@ -65,7 +49,7 @@ def _from_dotlist(config: list[str]) -> dict[str, str]:
65
49
  return result
66
50
 
67
51
 
68
- def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
52
+ def _iter_params(config: Any, prefix: str = "") -> Iterator[tuple[str, Any]]:
69
53
  if isinstance(config, DictConfig):
70
54
  for key, value in config.items():
71
55
  if _is_param(value):
@@ -83,12 +67,12 @@ def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
83
67
  yield from _iter_params(value, f"{prefix}{index}.")
84
68
 
85
69
 
86
- def _is_param(value: object) -> bool:
70
+ def _is_param(value: Any) -> bool:
87
71
  """Check if the given value is a parameter."""
88
72
  if isinstance(value, DictConfig):
89
73
  return False
90
74
 
91
- if isinstance(value, ListConfig): # noqa: SIM102
75
+ if isinstance(value, ListConfig):
92
76
  if any(isinstance(v, DictConfig | ListConfig) for v in value):
93
77
  return False
94
78
 
@@ -103,14 +87,14 @@ def _convert(value: Any) -> Any:
103
87
  return value
104
88
 
105
89
 
106
- def select_config(config: object, names: list[str]) -> dict[str, Any]:
90
+ def select_config(config: Any, names: list[str]) -> dict[str, Any]:
107
91
  """Select the given parameters from the configuration object.
108
92
 
109
93
  This function selects the given parameters from the configuration object
110
94
  and returns a new configuration object containing only the selected parameters.
111
95
 
112
96
  Args:
113
- config (object): The configuration object to select parameters from.
97
+ config (Any): The configuration object to select parameters from.
114
98
  names (list[str]): The names of the parameters to select.
115
99
 
116
100
  Returns:
@@ -120,7 +104,7 @@ def select_config(config: object, names: list[str]) -> dict[str, Any]:
120
104
  if not isinstance(config, DictConfig):
121
105
  config = OmegaConf.structured(config)
122
106
 
123
- return {name: _get(config, name) for name in names} # type: ignore
107
+ return {name: _get(config, name) for name in names}
124
108
 
125
109
 
126
110
  def _get(config: DictConfig, name: str) -> Any:
@@ -132,8 +116,7 @@ def _get(config: DictConfig, name: str) -> Any:
132
116
  return _get(config.get(prefix), name)
133
117
 
134
118
 
135
- def select_overrides(config: object) -> dict[str, Any]:
119
+ def select_overrides(config: object, overrides: list[str]) -> dict[str, Any]:
136
120
  """Select the given overrides from the configuration object."""
137
- overrides = get_overrides()
138
121
  names = [override.split("=")[0].strip() for override in overrides]
139
122
  return select_config(config, names)
@@ -12,8 +12,9 @@ import mlflow
12
12
  import mlflow.artifacts
13
13
  from hydra.core.hydra_config import HydraConfig
14
14
 
15
- from hydraflow.mlflow import log_params
16
- from hydraflow.utils import get_artifact_dir
15
+ from hydraflow.core.io import get_artifact_dir
16
+
17
+ from .mlflow import log_params, log_text
17
18
 
18
19
  if TYPE_CHECKING:
19
20
  from collections.abc import Iterator
@@ -55,11 +56,11 @@ def log_run(
55
56
  log_params(config, synchronous=synchronous)
56
57
 
57
58
  hc = HydraConfig.get()
58
- output_dir = Path(hc.runtime.output_dir)
59
+ hydra_dir = Path(hc.runtime.output_dir)
59
60
 
60
61
  # Save '.hydra' config directory.
61
- output_subdir = output_dir / (hc.output_subdir or "")
62
- mlflow.log_artifacts(output_subdir.as_posix(), hc.output_subdir)
62
+ hydra_subdir = hydra_dir / (hc.output_subdir or "")
63
+ mlflow.log_artifacts(hydra_subdir.as_posix(), hc.output_subdir)
63
64
 
64
65
  try:
65
66
  yield
@@ -70,43 +71,14 @@ def log_run(
70
71
  raise
71
72
 
72
73
  finally:
73
- log_text(output_dir)
74
-
75
-
76
- def log_text(directory: Path, pattern: str = "*.log") -> None:
77
- """Log text files in the given directory as artifacts.
78
-
79
- Append the text files to the existing text file in the artifact directory.
80
-
81
- Args:
82
- directory (Path): The directory to find the logs in.
83
- pattern (str): The pattern to match the logs.
84
-
85
- """
86
- artifact_dir = get_artifact_dir()
87
-
88
- for file in directory.glob(pattern):
89
- if not file.is_file():
90
- continue
91
-
92
- file_artifact = artifact_dir / file.name
93
- if file_artifact.exists():
94
- text = file_artifact.read_text()
95
- if not text.endswith("\n"):
96
- text += "\n"
97
- else:
98
- text = ""
99
-
100
- text += file.read_text()
101
- mlflow.log_text(text, file.name)
74
+ log_text(hydra_dir)
102
75
 
103
76
 
104
77
  @contextmanager
105
- def start_run( # noqa: PLR0913
78
+ def start_run(
106
79
  config: object,
107
80
  *,
108
81
  chdir: bool = False,
109
- run: Run | None = None,
110
82
  run_id: str | None = None,
111
83
  experiment_id: str | None = None,
112
84
  run_name: str | None = None,
@@ -126,7 +98,6 @@ def start_run( # noqa: PLR0913
126
98
  config (object): The configuration object to log parameters from.
127
99
  chdir (bool): Whether to change the current working directory to the
128
100
  artifact directory of the current run. Defaults to False.
129
- run (Run | None): The existing run. Defaults to None.
130
101
  run_id (str | None): The existing run ID. Defaults to None.
131
102
  experiment_id (str | None): The experiment ID. Defaults to None.
132
103
  run_name (str | None): The name of the run. Defaults to None.
@@ -142,20 +113,7 @@ def start_run( # noqa: PLR0913
142
113
  Yields:
143
114
  Run: An MLflow Run object representing the started run.
144
115
 
145
- Example:
146
- with start_run(config) as run:
147
- # Perform operations within the MLflow run context
148
- pass
149
-
150
- See Also:
151
- - `mlflow.start_run`: The MLflow function to start a run directly.
152
- - `log_run`: A context manager to log parameters and manage the MLflow
153
- run context.
154
-
155
116
  """
156
- if run:
157
- run_id = run.info.run_id
158
-
159
117
  with (
160
118
  mlflow.start_run(
161
119
  run_id=run_id,
@@ -12,46 +12,42 @@ import mlflow
12
12
  import mlflow.artifacts
13
13
  from hydra.core.hydra_config import HydraConfig
14
14
  from mlflow.entities import Run
15
- from omegaconf import DictConfig, OmegaConf
15
+ from omegaconf import DictConfig, ListConfig, OmegaConf
16
16
 
17
17
  if TYPE_CHECKING:
18
18
  from collections.abc import Iterable
19
19
 
20
20
 
21
- def get_artifact_dir(run: Run | None = None, uri: str | None = None) -> Path:
21
+ def file_uri_to_path(uri: str) -> Path:
22
+ """Convert a file URI to a local path."""
23
+ if not uri.startswith("file:"):
24
+ return Path(uri)
25
+
26
+ path = urllib.parse.urlparse(uri).path
27
+ return Path(urllib.request.url2pathname(path)) # for Windows
28
+
29
+
30
+ def get_artifact_dir(run: Run | None = None) -> Path:
22
31
  """Retrieve the artifact directory for the given run.
23
32
 
24
33
  This function uses MLflow to get the artifact directory for the given run.
25
34
 
26
35
  Args:
27
36
  run (Run | None): The run object. Defaults to None.
28
- uri (str | None): The URI of the artifact. Defaults to None.
29
37
 
30
38
  Returns:
31
39
  The local path to the directory where the artifacts are downloaded.
32
40
 
33
41
  """
34
- if run is not None and uri is not None:
35
- raise ValueError("Cannot provide both run and uri")
36
-
37
- if run is None and uri is None:
42
+ if run is None:
38
43
  uri = mlflow.get_artifact_uri()
39
- elif run:
44
+ else:
40
45
  uri = run.info.artifact_uri
41
46
 
42
47
  if not isinstance(uri, str):
43
48
  raise NotImplementedError
44
49
 
45
- if uri.startswith("file:"):
46
- return file_uri_to_path(uri)
47
-
48
- return Path(uri)
49
-
50
-
51
- def file_uri_to_path(uri: str) -> Path:
52
- """Convert a file URI to a local path."""
53
- path = urllib.parse.urlparse(uri).path
54
- return Path(urllib.request.url2pathname(path)) # for Windows
50
+ return file_uri_to_path(uri)
55
51
 
56
52
 
57
53
  def get_artifact_path(run: Run | None, path: str) -> Path:
@@ -123,12 +119,7 @@ def load_config(run: Run) -> DictConfig:
123
119
  return OmegaConf.load(path) # type: ignore
124
120
 
125
121
 
126
- def get_overrides() -> list[str]:
127
- """Retrieve the overrides for the current run."""
128
- return list(HydraConfig.get().overrides.task) # ListConifg -> list
129
-
130
-
131
- def load_overrides(run: Run) -> list[str]:
122
+ def load_overrides(run: Run) -> ListConfig:
132
123
  """Load the overrides for a given run.
133
124
 
134
125
  This function loads the overrides for the provided Run instance
@@ -137,15 +128,15 @@ def load_overrides(run: Run) -> list[str]:
137
128
  `.hydra/overrides.yaml` is not found in the run's artifact directory.
138
129
 
139
130
  Args:
140
- run (Run): The Run instance for which to load the overrides.
131
+ run (Run): The Run instance for which to load the configuration.
141
132
 
142
133
  Returns:
143
- The loaded overrides as a list of strings. Returns an empty list
144
- if the overrides file is not found.
134
+ The loaded configuration as a DictConfig object. Returns an empty
135
+ DictConfig if the configuration file is not found.
145
136
 
146
137
  """
147
138
  path = get_artifact_dir(run) / ".hydra/overrides.yaml"
148
- return [str(x) for x in OmegaConf.load(path)]
139
+ return OmegaConf.load(path) # type: ignore
149
140
 
150
141
 
151
142
  def remove_run(run: Run | Iterable[Run]) -> None:
hydraflow/core/main.py ADDED
@@ -0,0 +1,164 @@
1
+ """Integration of MLflow experiment tracking with Hydra configuration management.
2
+
3
+ This module provides decorators and utilities to seamlessly combine Hydra's
4
+ configuration management with MLflow's experiment tracking capabilities. It
5
+ enables automatic run deduplication, configuration storage, and experiment
6
+ management.
7
+
8
+ The main functionality is provided through the `main` decorator, which can be
9
+ used to wrap experiment entry points. This decorator handles:
10
+
11
+ - Configuration management via Hydra
12
+ - Experiment tracking via MLflow
13
+ - Run deduplication based on configurations
14
+ - Working directory management
15
+ - Automatic configuration storage
16
+
17
+ Example:
18
+ ```python
19
+ from dataclasses import dataclass
20
+ from mlflow.entities import Run
21
+
22
+ @dataclass
23
+ class Config:
24
+ learning_rate: float
25
+ batch_size: int
26
+
27
+ @main(Config)
28
+ def train(run: Run, config: Config):
29
+ # Your training code here
30
+ pass
31
+ ```
32
+
33
+ """
34
+
35
+ from __future__ import annotations
36
+
37
+ from functools import wraps
38
+ from typing import TYPE_CHECKING, TypeVar
39
+
40
+ import hydra
41
+ import mlflow
42
+ from hydra.core.config_store import ConfigStore
43
+ from hydra.core.hydra_config import HydraConfig
44
+ from mlflow.entities import RunStatus
45
+ from omegaconf import OmegaConf
46
+
47
+ import hydraflow
48
+ from hydraflow.core.io import file_uri_to_path
49
+
50
+ if TYPE_CHECKING:
51
+ from collections.abc import Callable
52
+ from pathlib import Path
53
+ from typing import Any
54
+
55
+ from mlflow.entities import Run
56
+
57
+ FINISHED = RunStatus.to_string(RunStatus.FINISHED)
58
+
59
+ T = TypeVar("T")
60
+
61
+
62
+ def main(
63
+ node: T | type[T],
64
+ config_name: str = "config",
65
+ *,
66
+ chdir: bool = False,
67
+ force_new_run: bool = False,
68
+ match_overrides: bool = False,
69
+ rerun_finished: bool = False,
70
+ ):
71
+ """Decorator for configuring and running MLflow experiments with Hydra.
72
+
73
+ This decorator combines Hydra configuration management with MLflow experiment
74
+ tracking. It automatically handles run deduplication and configuration storage.
75
+
76
+ Args:
77
+ node: Configuration node class or instance defining the structure of the
78
+ configuration.
79
+ config_name: Name of the configuration. Defaults to "config".
80
+ chdir: If True, changes working directory to the artifact directory
81
+ of the run. Defaults to False.
82
+ force_new_run: If True, always creates a new MLflow run instead of
83
+ reusing existing ones. Defaults to False.
84
+ match_overrides: If True, matches runs based on Hydra CLI overrides
85
+ instead of full config. Defaults to False.
86
+ rerun_finished: If True, allows rerunning completed runs. Defaults to
87
+ False.
88
+
89
+ """
90
+
91
+ def decorator(app: Callable[[Run, T], None]) -> Callable[[], None]:
92
+ ConfigStore.instance().store(config_name, node)
93
+
94
+ @hydra.main(config_name=config_name, version_base=None)
95
+ @wraps(app)
96
+ def inner_decorator(config: T) -> None:
97
+ hc = HydraConfig.get()
98
+ experiment = mlflow.set_experiment(hc.job.name)
99
+
100
+ if force_new_run:
101
+ run_id = None
102
+ else:
103
+ uri = experiment.artifact_location
104
+ overrides = hc.overrides.task if match_overrides else None
105
+ run_id = get_run_id(uri, config, overrides)
106
+
107
+ if run_id and not rerun_finished:
108
+ run = mlflow.get_run(run_id)
109
+ if run.info.status == FINISHED:
110
+ return
111
+
112
+ with hydraflow.start_run(config, run_id=run_id, chdir=chdir) as run:
113
+ app(run, config)
114
+
115
+ return inner_decorator
116
+
117
+ return decorator
118
+
119
+
120
+ def get_run_id(uri: str, config: Any, overrides: list[str] | None) -> str | None:
121
+ """Try to get the run ID for the given configuration.
122
+
123
+ If the run is not found, the function will return None.
124
+
125
+ Args:
126
+ uri (str): The URI of the experiment.
127
+ config (object): The configuration object.
128
+ overrides (list[str] | None): The task overrides.
129
+
130
+ Returns:
131
+ The run ID for the given configuration or overrides. Returns None if
132
+ no run ID is found.
133
+
134
+ """
135
+ for run_dir in file_uri_to_path(uri).iterdir():
136
+ if run_dir.is_dir() and equals(run_dir, config, overrides):
137
+ return run_dir.name
138
+
139
+ return None
140
+
141
+
142
+ def equals(run_dir: Path, config: Any, overrides: list[str] | None) -> bool:
143
+ """Check if the run directory matches the given configuration or overrides.
144
+
145
+ Args:
146
+ run_dir (Path): The run directory.
147
+ config (object): The configuration object.
148
+ overrides (list[str] | None): The task overrides.
149
+
150
+ Returns:
151
+ True if the run directory matches the given configuration or overrides,
152
+ False otherwise.
153
+
154
+ """
155
+ if overrides is None:
156
+ path = run_dir / "artifacts/.hydra/config.yaml"
157
+ else:
158
+ path = run_dir / "artifacts/.hydra/overrides.yaml"
159
+ config = overrides
160
+
161
+ if not path.exists():
162
+ return False
163
+
164
+ return OmegaConf.load(path) == config