hydraflow 0.7.5__tar.gz → 0.8.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (83) hide show
  1. {hydraflow-0.7.5 → hydraflow-0.8.0}/PKG-INFO +2 -2
  2. {hydraflow-0.7.5 → hydraflow-0.8.0}/README.md +1 -1
  3. {hydraflow-0.7.5 → hydraflow-0.8.0}/apps/quickstart.py +5 -2
  4. {hydraflow-0.7.5 → hydraflow-0.8.0}/docs/usage/quickstart.md +0 -12
  5. {hydraflow-0.7.5 → hydraflow-0.8.0}/pyproject.toml +10 -14
  6. {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/__init__.py +1 -16
  7. {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/config.py +10 -27
  8. {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/context.py +6 -49
  9. hydraflow-0.8.0/src/hydraflow/main.py +162 -0
  10. hydraflow-0.8.0/src/hydraflow/mlflow.py +167 -0
  11. {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/param.py +2 -2
  12. {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/run_collection.py +10 -156
  13. {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/run_data.py +4 -2
  14. {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/utils.py +19 -28
  15. hydraflow-0.8.0/tests/config/test_config.py +54 -0
  16. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/config/test_params.py +11 -27
  17. hydraflow-0.8.0/tests/conftest.py +46 -0
  18. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/context/chdir.py +5 -2
  19. hydraflow-0.7.5/tests/context/logging.py → hydraflow-0.8.0/tests/context/log_run.py +11 -8
  20. hydraflow-0.7.5/tests/context/context.py → hydraflow-0.8.0/tests/context/start_run.py +7 -13
  21. hydraflow-0.7.5/tests/context/test_logging.py → hydraflow-0.8.0/tests/context/test_log_run.py +23 -19
  22. hydraflow-0.7.5/tests/context/test_context.py → hydraflow-0.8.0/tests/context/test_start_run.py +13 -5
  23. hydraflow-0.7.5/tests/main/base.py → hydraflow-0.8.0/tests/main/default.py +2 -1
  24. hydraflow-0.8.0/tests/main/match_overrides.py +24 -0
  25. hydraflow-0.7.5/tests/main/restart.py → hydraflow-0.8.0/tests/main/rerun_finished.py +9 -2
  26. hydraflow-0.7.5/tests/main/test_base.py → hydraflow-0.8.0/tests/main/test_default.py +10 -6
  27. hydraflow-0.8.0/tests/main/test_match_overrides.py +22 -0
  28. hydraflow-0.7.5/tests/main/test_restart.py → hydraflow-0.8.0/tests/main/test_rerun_finished.py +1 -1
  29. hydraflow-0.7.5/tests/main/test_skip.py → hydraflow-0.8.0/tests/main/test_skip_finished.py +1 -1
  30. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/param/params.py +5 -2
  31. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/param/test_params.py +14 -3
  32. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/run/filter.py +6 -3
  33. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/run/test_collection.py +26 -69
  34. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/run/test_data.py +2 -2
  35. hydraflow-0.8.0/tests/run/test_filter.py +41 -0
  36. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/run/test_info.py +2 -2
  37. hydraflow-0.8.0/tests/run/test_values.py +34 -0
  38. hydraflow-0.7.5/tests/run/run.py → hydraflow-0.8.0/tests/run/values.py +7 -9
  39. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/test_mlflow.py +8 -26
  40. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/utils/test_run.py +2 -2
  41. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/utils/test_utils.py +1 -17
  42. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/utils/utils.py +4 -5
  43. hydraflow-0.7.5/src/hydraflow/main.py +0 -54
  44. hydraflow-0.7.5/src/hydraflow/mlflow.py +0 -280
  45. hydraflow-0.7.5/tests/config/overrides.py +0 -32
  46. hydraflow-0.7.5/tests/config/test_config.py +0 -29
  47. hydraflow-0.7.5/tests/config/test_overrides.py +0 -25
  48. hydraflow-0.7.5/tests/conftest.py +0 -81
  49. hydraflow-0.7.5/tests/context/rerun.py +0 -40
  50. hydraflow-0.7.5/tests/context/test_rerun.py +0 -31
  51. hydraflow-0.7.5/tests/run/test_filter.py +0 -19
  52. hydraflow-0.7.5/tests/run/test_run.py +0 -54
  53. {hydraflow-0.7.5 → hydraflow-0.8.0}/.devcontainer/devcontainer.json +0 -0
  54. {hydraflow-0.7.5 → hydraflow-0.8.0}/.devcontainer/postCreate.sh +0 -0
  55. {hydraflow-0.7.5 → hydraflow-0.8.0}/.devcontainer/starship.toml +0 -0
  56. {hydraflow-0.7.5 → hydraflow-0.8.0}/.gitattributes +0 -0
  57. {hydraflow-0.7.5 → hydraflow-0.8.0}/.github/workflows/ci.yaml +0 -0
  58. {hydraflow-0.7.5 → hydraflow-0.8.0}/.github/workflows/docs.yaml +0 -0
  59. {hydraflow-0.7.5 → hydraflow-0.8.0}/.gitignore +0 -0
  60. {hydraflow-0.7.5 → hydraflow-0.8.0}/LICENSE +0 -0
  61. {hydraflow-0.7.5 → hydraflow-0.8.0}/docs/index.md +0 -0
  62. {hydraflow-0.7.5 → hydraflow-0.8.0}/hydraflow.yaml +0 -0
  63. {hydraflow-0.7.5 → hydraflow-0.8.0}/mkdocs.yaml +0 -0
  64. {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/cli.py +0 -0
  65. {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/py.typed +0 -0
  66. {hydraflow-0.7.5 → hydraflow-0.8.0}/src/hydraflow/run_info.py +0 -0
  67. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/__init__.py +0 -0
  68. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/cli/__init__.py +0 -0
  69. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/cli/conftest.py +0 -0
  70. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/cli/test_run.py +0 -0
  71. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/cli/test_show.py +0 -0
  72. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/cli/test_version.py +0 -0
  73. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/config/__init__.py +0 -0
  74. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/context/__init__.py +0 -0
  75. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/context/test_chdir.py +0 -0
  76. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/main/__init__.py +0 -0
  77. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/main/force_new_run.py +0 -0
  78. /hydraflow-0.7.5/tests/main/skip.py → /hydraflow-0.8.0/tests/main/skip_finished.py +0 -0
  79. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/main/test_force_new_run.py +0 -0
  80. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/param/__init__.py +0 -0
  81. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/param/test_param.py +0 -0
  82. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/run/__init__.py +0 -0
  83. {hydraflow-0.7.5 → hydraflow-0.8.0}/tests/utils/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hydraflow
3
- Version: 0.7.5
3
+ Version: 0.8.0
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://daizutabi.github.io/hydraflow/
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -108,7 +108,7 @@ class MySQLConfig:
108
108
  cs = ConfigStore.instance()
109
109
  cs.store(name="config", node=MySQLConfig)
110
110
 
111
- @hydra.main(version_base=None, config_name="config")
111
+ @hydra.main(config_name="config", version_base=None)
112
112
  def my_app(cfg: MySQLConfig) -> None:
113
113
  # Set experiment by Hydra job name.
114
114
  hydraflow.set_experiment()
@@ -63,7 +63,7 @@ class MySQLConfig:
63
63
  cs = ConfigStore.instance()
64
64
  cs.store(name="config", node=MySQLConfig)
65
65
 
66
- @hydra.main(version_base=None, config_name="config")
66
+ @hydra.main(config_name="config", version_base=None)
67
67
  def my_app(cfg: MySQLConfig) -> None:
68
68
  # Set experiment by Hydra job name.
69
69
  hydraflow.set_experiment()
@@ -2,7 +2,9 @@ import logging
2
2
  from dataclasses import dataclass
3
3
 
4
4
  import hydra
5
+ import mlflow
5
6
  from hydra.core.config_store import ConfigStore
7
+ from hydra.core.hydra_config import HydraConfig
6
8
 
7
9
  import hydraflow
8
10
 
@@ -19,9 +21,10 @@ cs = ConfigStore.instance()
19
21
  cs.store(name="config", node=Config)
20
22
 
21
23
 
22
- @hydra.main(version_base=None, config_name="config")
24
+ @hydra.main(config_name="config", version_base=None)
23
25
  def app(cfg: Config) -> None:
24
- hydraflow.set_experiment()
26
+ hc = HydraConfig.get()
27
+ mlflow.set_experiment(hc.job.name)
25
28
 
26
29
  with hydraflow.start_run(cfg):
27
30
  log.info(f"{cfg.width=}, {cfg.height=}")
@@ -117,18 +117,6 @@ $ python apps/quickstart.py -m width=400,600 height=100,200,300
117
117
  >>> print(run.data.params)
118
118
  ```
119
119
 
120
- ### Map runs
121
-
122
- ```pycon exec="1" source="console" session="quickstart"
123
- >>> params = rc.map(lambda x: x.data.params)
124
- >>> for p in params:
125
- ... print(p)
126
- ```
127
-
128
- ```pycon exec="1" source="console" session="quickstart"
129
- >>> list(rc.map_id(print))
130
- ```
131
-
132
120
  ### Group runs
133
121
 
134
122
  ```pycon exec="1" source="console" session="quickstart"
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.7.5"
7
+ version = "0.8.0"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -72,24 +72,20 @@ ignore = [
72
72
  "D203",
73
73
  "D213",
74
74
  "EM101",
75
+ "FBT001",
76
+ "FBT002",
75
77
  "PGH003",
78
+ "PLR0911",
79
+ "PLR0913",
76
80
  "PLR1704",
81
+ "PLR2004",
82
+ "SIM102",
83
+ "SIM108",
77
84
  "TRY003",
78
85
  ]
79
86
 
80
87
  [tool.ruff.lint.per-file-ignores]
81
- "tests/*" = [
82
- "A001",
83
- "ANN",
84
- "ARG",
85
- "D",
86
- "FBT",
87
- "PLR",
88
- "PT",
89
- "S",
90
- "SIM108",
91
- "SLF",
92
- ]
88
+ "tests/*" = ["A001", "ANN", "ARG", "D", "FBT", "PD", "PLR", "PT", "S", "SLF"]
93
89
  "apps/*.py" = ["D", "G", "INP"]
94
- "src/hydraflow/main.py" = ["ANN201", "D401", "PLR0913"]
90
+ "src/hydraflow/main.py" = ["ANN201", "D401"]
95
91
  "src/hydraflow/cli.py" = ["ANN", "D"]
@@ -1,23 +1,14 @@
1
1
  """Integrate Hydra and MLflow to manage and track machine learning experiments."""
2
2
 
3
- from hydraflow.config import select_config, select_overrides
4
3
  from hydraflow.context import chdir_artifact, log_run, start_run
5
4
  from hydraflow.main import main
6
- from hydraflow.mlflow import (
7
- list_run_ids,
8
- list_run_paths,
9
- list_runs,
10
- search_runs,
11
- set_experiment,
12
- )
5
+ from hydraflow.mlflow import list_run_ids, list_run_paths, list_runs
13
6
  from hydraflow.run_collection import RunCollection
14
7
  from hydraflow.utils import (
15
8
  get_artifact_dir,
16
9
  get_artifact_path,
17
10
  get_hydra_output_dir,
18
- get_overrides,
19
11
  load_config,
20
- load_overrides,
21
12
  remove_run,
22
13
  )
23
14
 
@@ -27,18 +18,12 @@ __all__ = [
27
18
  "get_artifact_dir",
28
19
  "get_artifact_path",
29
20
  "get_hydra_output_dir",
30
- "get_overrides",
31
21
  "list_run_ids",
32
22
  "list_run_paths",
33
23
  "list_runs",
34
24
  "load_config",
35
- "load_overrides",
36
25
  "log_run",
37
26
  "main",
38
27
  "remove_run",
39
- "search_runs",
40
- "select_config",
41
- "select_overrides",
42
- "set_experiment",
43
28
  "start_run",
44
29
  ]
@@ -6,35 +6,19 @@ from typing import TYPE_CHECKING
6
6
 
7
7
  from omegaconf import DictConfig, ListConfig, OmegaConf
8
8
 
9
- from hydraflow.utils import get_overrides
10
-
11
9
  if TYPE_CHECKING:
12
10
  from collections.abc import Iterator
13
11
  from typing import Any
14
12
 
15
13
 
16
- def collect_params(config: object) -> dict[str, Any]:
17
- """Iterate over parameters and collect them into a dictionary.
18
-
19
- Args:
20
- config (object): The configuration object to iterate over.
21
- prefix (str): The prefix to prepend to the parameter keys.
22
-
23
- Returns:
24
- dict[str, Any]: A dictionary of collected parameters.
25
-
26
- """
27
- return dict(iter_params(config))
28
-
29
-
30
- def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
14
+ def iter_params(config: Any, prefix: str = "") -> Iterator[tuple[str, Any]]:
31
15
  """Recursively iterate over the parameters in the given configuration object.
32
16
 
33
17
  This function traverses the configuration object and yields key-value pairs
34
18
  representing the parameters. The keys are prefixed with the provided prefix.
35
19
 
36
20
  Args:
37
- config (object): The configuration object to iterate over. This can be a
21
+ config (Any): The configuration object to iterate over. This can be a
38
22
  dictionary, list, DictConfig, or ListConfig.
39
23
  prefix (str): The prefix to prepend to the parameter keys.
40
24
  Defaults to an empty string.
@@ -50,7 +34,7 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
50
34
  config = _from_dotlist(config)
51
35
 
52
36
  if not isinstance(config, DictConfig | ListConfig):
53
- config = OmegaConf.create(config) # type: ignore
37
+ config = OmegaConf.create(config)
54
38
 
55
39
  yield from _iter_params(config, prefix)
56
40
 
@@ -65,7 +49,7 @@ def _from_dotlist(config: list[str]) -> dict[str, str]:
65
49
  return result
66
50
 
67
51
 
68
- def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
52
+ def _iter_params(config: Any, prefix: str = "") -> Iterator[tuple[str, Any]]:
69
53
  if isinstance(config, DictConfig):
70
54
  for key, value in config.items():
71
55
  if _is_param(value):
@@ -83,12 +67,12 @@ def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
83
67
  yield from _iter_params(value, f"{prefix}{index}.")
84
68
 
85
69
 
86
- def _is_param(value: object) -> bool:
70
+ def _is_param(value: Any) -> bool:
87
71
  """Check if the given value is a parameter."""
88
72
  if isinstance(value, DictConfig):
89
73
  return False
90
74
 
91
- if isinstance(value, ListConfig): # noqa: SIM102
75
+ if isinstance(value, ListConfig):
92
76
  if any(isinstance(v, DictConfig | ListConfig) for v in value):
93
77
  return False
94
78
 
@@ -103,14 +87,14 @@ def _convert(value: Any) -> Any:
103
87
  return value
104
88
 
105
89
 
106
- def select_config(config: object, names: list[str]) -> dict[str, Any]:
90
+ def select_config(config: Any, names: list[str]) -> dict[str, Any]:
107
91
  """Select the given parameters from the configuration object.
108
92
 
109
93
  This function selects the given parameters from the configuration object
110
94
  and returns a new configuration object containing only the selected parameters.
111
95
 
112
96
  Args:
113
- config (object): The configuration object to select parameters from.
97
+ config (Any): The configuration object to select parameters from.
114
98
  names (list[str]): The names of the parameters to select.
115
99
 
116
100
  Returns:
@@ -120,7 +104,7 @@ def select_config(config: object, names: list[str]) -> dict[str, Any]:
120
104
  if not isinstance(config, DictConfig):
121
105
  config = OmegaConf.structured(config)
122
106
 
123
- return {name: _get(config, name) for name in names} # type: ignore
107
+ return {name: _get(config, name) for name in names}
124
108
 
125
109
 
126
110
  def _get(config: DictConfig, name: str) -> Any:
@@ -132,8 +116,7 @@ def _get(config: DictConfig, name: str) -> Any:
132
116
  return _get(config.get(prefix), name)
133
117
 
134
118
 
135
- def select_overrides(config: object) -> dict[str, Any]:
119
+ def select_overrides(config: object, overrides: list[str]) -> dict[str, Any]:
136
120
  """Select the given overrides from the configuration object."""
137
- overrides = get_overrides()
138
121
  names = [override.split("=")[0].strip() for override in overrides]
139
122
  return select_config(config, names)
@@ -12,7 +12,7 @@ import mlflow
12
12
  import mlflow.artifacts
13
13
  from hydra.core.hydra_config import HydraConfig
14
14
 
15
- from hydraflow.mlflow import log_params
15
+ from hydraflow.mlflow import log_params, log_text
16
16
  from hydraflow.utils import get_artifact_dir
17
17
 
18
18
  if TYPE_CHECKING:
@@ -55,11 +55,11 @@ def log_run(
55
55
  log_params(config, synchronous=synchronous)
56
56
 
57
57
  hc = HydraConfig.get()
58
- output_dir = Path(hc.runtime.output_dir)
58
+ hydra_dir = Path(hc.runtime.output_dir)
59
59
 
60
60
  # Save '.hydra' config directory.
61
- output_subdir = output_dir / (hc.output_subdir or "")
62
- mlflow.log_artifacts(output_subdir.as_posix(), hc.output_subdir)
61
+ hydra_subdir = hydra_dir / (hc.output_subdir or "")
62
+ mlflow.log_artifacts(hydra_subdir.as_posix(), hc.output_subdir)
63
63
 
64
64
  try:
65
65
  yield
@@ -70,43 +70,14 @@ def log_run(
70
70
  raise
71
71
 
72
72
  finally:
73
- log_text(output_dir)
74
-
75
-
76
- def log_text(directory: Path, pattern: str = "*.log") -> None:
77
- """Log text files in the given directory as artifacts.
78
-
79
- Append the text files to the existing text file in the artifact directory.
80
-
81
- Args:
82
- directory (Path): The directory to find the logs in.
83
- pattern (str): The pattern to match the logs.
84
-
85
- """
86
- artifact_dir = get_artifact_dir()
87
-
88
- for file in directory.glob(pattern):
89
- if not file.is_file():
90
- continue
91
-
92
- file_artifact = artifact_dir / file.name
93
- if file_artifact.exists():
94
- text = file_artifact.read_text()
95
- if not text.endswith("\n"):
96
- text += "\n"
97
- else:
98
- text = ""
99
-
100
- text += file.read_text()
101
- mlflow.log_text(text, file.name)
73
+ log_text(hydra_dir)
102
74
 
103
75
 
104
76
  @contextmanager
105
- def start_run( # noqa: PLR0913
77
+ def start_run(
106
78
  config: object,
107
79
  *,
108
80
  chdir: bool = False,
109
- run: Run | None = None,
110
81
  run_id: str | None = None,
111
82
  experiment_id: str | None = None,
112
83
  run_name: str | None = None,
@@ -126,7 +97,6 @@ def start_run( # noqa: PLR0913
126
97
  config (object): The configuration object to log parameters from.
127
98
  chdir (bool): Whether to change the current working directory to the
128
99
  artifact directory of the current run. Defaults to False.
129
- run (Run | None): The existing run. Defaults to None.
130
100
  run_id (str | None): The existing run ID. Defaults to None.
131
101
  experiment_id (str | None): The experiment ID. Defaults to None.
132
102
  run_name (str | None): The name of the run. Defaults to None.
@@ -142,20 +112,7 @@ def start_run( # noqa: PLR0913
142
112
  Yields:
143
113
  Run: An MLflow Run object representing the started run.
144
114
 
145
- Example:
146
- with start_run(config) as run:
147
- # Perform operations within the MLflow run context
148
- pass
149
-
150
- See Also:
151
- - `mlflow.start_run`: The MLflow function to start a run directly.
152
- - `log_run`: A context manager to log parameters and manage the MLflow
153
- run context.
154
-
155
115
  """
156
- if run:
157
- run_id = run.info.run_id
158
-
159
116
  with (
160
117
  mlflow.start_run(
161
118
  run_id=run_id,
@@ -0,0 +1,162 @@
1
+ """Integration of MLflow experiment tracking with Hydra configuration management.
2
+
3
+ This module provides decorators and utilities to seamlessly combine Hydra's
4
+ configuration management with MLflow's experiment tracking capabilities. It
5
+ enables automatic run deduplication, configuration storage, and experiment
6
+ management.
7
+
8
+ The main functionality is provided through the `main` decorator, which can be
9
+ used to wrap experiment entry points. This decorator handles:
10
+ - Configuration management via Hydra
11
+ - Experiment tracking via MLflow
12
+ - Run deduplication based on configurations
13
+ - Working directory management
14
+ - Automatic configuration storage
15
+
16
+ Example:
17
+ ```python
18
+ from dataclasses import dataclass
19
+ from mlflow.entities import Run
20
+
21
+ @dataclass
22
+ class Config:
23
+ learning_rate: float
24
+ batch_size: int
25
+
26
+ @main(Config)
27
+ def train(run: Run, config: Config):
28
+ # Your training code here
29
+ pass
30
+ ```
31
+
32
+ """
33
+
34
+ from __future__ import annotations
35
+
36
+ from functools import wraps
37
+ from typing import TYPE_CHECKING, TypeVar
38
+
39
+ import hydra
40
+ import mlflow
41
+ from hydra.core.config_store import ConfigStore
42
+ from hydra.core.hydra_config import HydraConfig
43
+ from mlflow.entities import RunStatus
44
+ from omegaconf import OmegaConf
45
+
46
+ import hydraflow
47
+ from hydraflow.utils import file_uri_to_path
48
+
49
+ if TYPE_CHECKING:
50
+ from collections.abc import Callable
51
+ from pathlib import Path
52
+
53
+ from mlflow.entities import Run
54
+
55
+ FINISHED = RunStatus.to_string(RunStatus.FINISHED)
56
+
57
+ T = TypeVar("T")
58
+
59
+
60
+ def main(
61
+ node: T | type[T],
62
+ config_name: str = "config",
63
+ *,
64
+ chdir: bool = False,
65
+ force_new_run: bool = False,
66
+ match_overrides: bool = False,
67
+ rerun_finished: bool = False,
68
+ ):
69
+ """Decorator for configuring and running MLflow experiments with Hydra.
70
+
71
+ This decorator combines Hydra configuration management with MLflow experiment
72
+ tracking. It automatically handles run deduplication and configuration storage.
73
+
74
+ Args:
75
+ node: Configuration node class or instance defining the structure of the
76
+ configuration.
77
+ config_name: Name of the configuration. Defaults to "config".
78
+ chdir: If True, changes working directory to the artifact directory
79
+ of the run. Defaults to False.
80
+ force_new_run: If True, always creates a new MLflow run instead of
81
+ reusing existing ones. Defaults to False.
82
+ match_overrides: If True, matches runs based on Hydra CLI overrides
83
+ instead of full config. Defaults to False.
84
+ rerun_finished: If True, allows rerunning completed runs. Defaults to
85
+ False.
86
+
87
+ """
88
+
89
+ def decorator(app: Callable[[Run, T], None]) -> Callable[[], None]:
90
+ ConfigStore.instance().store(config_name, node)
91
+
92
+ @hydra.main(config_name=config_name, version_base=None)
93
+ @wraps(app)
94
+ def inner_decorator(config: T) -> None:
95
+ hc = HydraConfig.get()
96
+ experiment = mlflow.set_experiment(hc.job.name)
97
+
98
+ if force_new_run:
99
+ run_id = None
100
+ else:
101
+ uri = experiment.artifact_location
102
+ overrides = hc.overrides.task if match_overrides else None
103
+ run_id = get_run_id(uri, config, overrides)
104
+
105
+ if run_id and not rerun_finished:
106
+ run = mlflow.get_run(run_id)
107
+ if run.info.status == FINISHED:
108
+ return
109
+
110
+ with hydraflow.start_run(config, run_id=run_id, chdir=chdir) as run:
111
+ app(run, config)
112
+
113
+ return inner_decorator
114
+
115
+ return decorator
116
+
117
+
118
+ def get_run_id(uri: str, config: object, overrides: list[str] | None) -> str | None:
119
+ """Try to get the run ID for the given configuration.
120
+
121
+ If the run is not found, the function will return None.
122
+
123
+ Args:
124
+ uri (str): The URI of the experiment.
125
+ config (object): The configuration object.
126
+ overrides (list[str] | None): The task overrides.
127
+
128
+ Returns:
129
+ The run ID for the given configuration or overrides. Returns None if
130
+ no run ID is found.
131
+
132
+ """
133
+ for run_dir in file_uri_to_path(uri).iterdir():
134
+ if run_dir.is_dir() and equals(run_dir, config, overrides):
135
+ return run_dir.name
136
+
137
+ return None
138
+
139
+
140
+ def equals(run_dir: Path, config: object, overrides: list[str] | None) -> bool:
141
+ """Check if the run directory matches the given configuration or overrides.
142
+
143
+ Args:
144
+ run_dir (Path): The run directory.
145
+ config (object): The configuration object.
146
+ overrides (list[str] | None): The task overrides.
147
+
148
+ Returns:
149
+ True if the run directory matches the given configuration or overrides,
150
+ False otherwise.
151
+
152
+ """
153
+ if overrides is None:
154
+ path = run_dir / "artifacts/.hydra/config.yaml"
155
+ else:
156
+ path = run_dir / "artifacts/.hydra/overrides.yaml"
157
+ config = overrides
158
+
159
+ if not path.exists():
160
+ return False
161
+
162
+ return OmegaConf.load(path) == config
@@ -0,0 +1,167 @@
1
+ """Integration of MLflow experiment tracking with Hydra configuration management.
2
+
3
+ This module provides functions to log parameters from Hydra configuration objects
4
+ to MLflow, set experiments, and manage tracking URIs. It integrates Hydra's
5
+ configuration management with MLflow's experiment tracking capabilities.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import TYPE_CHECKING
11
+
12
+ import joblib
13
+ import mlflow
14
+ import mlflow.artifacts
15
+
16
+ from hydraflow.config import iter_params
17
+ from hydraflow.run_collection import RunCollection
18
+ from hydraflow.utils import file_uri_to_path, get_artifact_dir
19
+
20
+ if TYPE_CHECKING:
21
+ from pathlib import Path
22
+ from typing import Any
23
+
24
+
25
+ def log_params(config: Any, *, synchronous: bool | None = None) -> None:
26
+ """Log the parameters from the given configuration object.
27
+
28
+ This method logs the parameters from the provided configuration object
29
+ using MLflow. It iterates over the parameters and logs them using the
30
+ `mlflow.log_param` method.
31
+
32
+ Args:
33
+ config (Any): The configuration object to log the parameters from.
34
+ synchronous (bool | None): Whether to log the parameters synchronously.
35
+ Defaults to None.
36
+
37
+ """
38
+ for key, value in iter_params(config):
39
+ mlflow.log_param(key, value, synchronous=synchronous)
40
+
41
+
42
+ def log_text(from_dir: Path, pattern: str = "*.log") -> None:
43
+ """Log text files in the given directory as artifacts.
44
+
45
+ Append the text files to the existing text file in the artifact directory.
46
+
47
+ Args:
48
+ from_dir (Path): The directory to find the logs in.
49
+ pattern (str): The pattern to match the logs.
50
+
51
+ """
52
+ artifact_dir = get_artifact_dir()
53
+
54
+ for file in from_dir.glob(pattern):
55
+ if not file.is_file():
56
+ continue
57
+
58
+ file_artifact = artifact_dir / file.name
59
+ if file_artifact.exists():
60
+ text = file_artifact.read_text()
61
+ if not text.endswith("\n"):
62
+ text += "\n"
63
+ else:
64
+ text = ""
65
+
66
+ text += file.read_text()
67
+ mlflow.log_text(text, file.name)
68
+
69
+
70
+ def list_run_paths(
71
+ experiment_names: str | list[str] | None = None,
72
+ *other: str,
73
+ ) -> list[Path]:
74
+ """List all run paths for the specified experiments.
75
+
76
+ This function retrieves all run paths for the given list of experiment names.
77
+ If no experiment names are provided (None), the function will search all runs
78
+ for all experiments except the "Default" experiment.
79
+
80
+ Args:
81
+ experiment_names (list[str] | None): List of experiment names to search
82
+ for runs. If None is provided, the function will search all runs
83
+ for all experiments except the "Default" experiment.
84
+ *other (str): The parts of the run directory to join.
85
+
86
+ Returns:
87
+ list[Path]: A list of run paths for the specified experiments.
88
+
89
+ """
90
+ if isinstance(experiment_names, str):
91
+ experiment_names = [experiment_names]
92
+
93
+ elif experiment_names is None:
94
+ experiments = mlflow.search_experiments()
95
+ experiment_names = [e.name for e in experiments if e.name != "Default"]
96
+
97
+ run_paths: list[Path] = []
98
+
99
+ for name in experiment_names:
100
+ if experiment := mlflow.get_experiment_by_name(name):
101
+ uri = experiment.artifact_location
102
+
103
+ if isinstance(uri, str):
104
+ path = file_uri_to_path(uri)
105
+ run_paths.extend(p for p in path.iterdir() if p.is_dir())
106
+
107
+ if other:
108
+ return [p.joinpath(*other) for p in run_paths]
109
+
110
+ return run_paths
111
+
112
+
113
+ def list_run_ids(experiment_names: str | list[str] | None = None) -> list[str]:
114
+ """List all run IDs for the specified experiments.
115
+
116
+ This function retrieves all runs for the given list of experiment names.
117
+ If no experiment names are provided (None), the function will search all
118
+ runs for all experiments except the "Default" experiment.
119
+
120
+ Args:
121
+ experiment_names (list[str] | None): List of experiment names to search
122
+ for runs. If None is provided, the function will search all runs
123
+ for all experiments except the "Default" experiment.
124
+
125
+ Returns:
126
+ list[str]: A list of run IDs for the specified experiments.
127
+
128
+ """
129
+ return [run_path.stem for run_path in list_run_paths(experiment_names)]
130
+
131
+
132
+ def list_runs(
133
+ experiment_names: str | list[str] | None = None,
134
+ n_jobs: int = 0,
135
+ ) -> RunCollection:
136
+ """List all runs for the specified experiments.
137
+
138
+ This function retrieves all runs for the given list of experiment names.
139
+ If no experiment names are provided (None), the function will search all runs
140
+ for all experiments except the "Default" experiment.
141
+ The function returns the results as a `RunCollection` object.
142
+
143
+ Note:
144
+ The returned runs are sorted by their start time in ascending order.
145
+
146
+ Args:
147
+ experiment_names (list[str] | None): List of experiment names to search
148
+ for runs. If None is provided, the function will search all runs
149
+ for all experiments except the "Default" experiment.
150
+ n_jobs (int): The number of jobs to retrieve runs in parallel.
151
+
152
+ Returns:
153
+ RunCollection: A `RunCollection` instance containing the runs for the
154
+ specified experiments.
155
+
156
+ """
157
+ run_ids = list_run_ids(experiment_names)
158
+
159
+ if n_jobs == 0:
160
+ runs = [mlflow.get_run(run_id) for run_id in run_ids]
161
+
162
+ else:
163
+ it = (joblib.delayed(mlflow.get_run)(run_id) for run_id in run_ids)
164
+ runs = joblib.Parallel(n_jobs, backend="threading")(it)
165
+
166
+ runs = sorted(runs, key=lambda run: run.info.start_time) # type: ignore
167
+ return RunCollection(runs) # type: ignore