hydraflow 0.7.5__tar.gz → 0.9.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. {hydraflow-0.7.5 → hydraflow-0.9.0}/PKG-INFO +18 -19
  2. {hydraflow-0.7.5 → hydraflow-0.9.0}/README.md +16 -18
  3. {hydraflow-0.7.5 → hydraflow-0.9.0}/apps/quickstart.py +5 -2
  4. {hydraflow-0.7.5 → hydraflow-0.9.0}/docs/usage/quickstart.md +1 -35
  5. {hydraflow-0.7.5 → hydraflow-0.9.0}/pyproject.toml +22 -15
  6. hydraflow-0.9.0/src/hydraflow/__init__.py +29 -0
  7. hydraflow-0.9.0/src/hydraflow/cli.py +67 -0
  8. {hydraflow-0.7.5/src/hydraflow → hydraflow-0.9.0/src/hydraflow/core}/config.py +10 -27
  9. {hydraflow-0.7.5/src/hydraflow → hydraflow-0.9.0/src/hydraflow/core}/context.py +8 -50
  10. hydraflow-0.7.5/src/hydraflow/utils.py → hydraflow-0.9.0/src/hydraflow/core/io.py +19 -28
  11. hydraflow-0.9.0/src/hydraflow/core/main.py +164 -0
  12. hydraflow-0.9.0/src/hydraflow/core/mlflow.py +168 -0
  13. {hydraflow-0.7.5/src/hydraflow → hydraflow-0.9.0/src/hydraflow/core}/param.py +2 -2
  14. {hydraflow-0.7.5/src/hydraflow → hydraflow-0.9.0/src/hydraflow/entities}/run_collection.py +18 -163
  15. {hydraflow-0.7.5/src/hydraflow → hydraflow-0.9.0/src/hydraflow/entities}/run_data.py +5 -3
  16. {hydraflow-0.7.5/src/hydraflow → hydraflow-0.9.0/src/hydraflow/entities}/run_info.py +2 -2
  17. hydraflow-0.9.0/src/hydraflow/executor/conf.py +23 -0
  18. hydraflow-0.9.0/src/hydraflow/executor/io.py +34 -0
  19. hydraflow-0.9.0/src/hydraflow/executor/job.py +152 -0
  20. hydraflow-0.9.0/src/hydraflow/executor/parser.py +397 -0
  21. hydraflow-0.9.0/tests/cli/app.py +24 -0
  22. hydraflow-0.9.0/tests/cli/conftest.py +12 -0
  23. hydraflow-0.9.0/tests/cli/hydraflow.yaml +13 -0
  24. hydraflow-0.9.0/tests/cli/test_run.py +33 -0
  25. hydraflow-0.9.0/tests/cli/test_setup.py +7 -0
  26. hydraflow-0.9.0/tests/cli/test_show.py +23 -0
  27. hydraflow-0.9.0/tests/conftest.py +51 -0
  28. hydraflow-0.9.0/tests/core/config/test_config.py +54 -0
  29. {hydraflow-0.7.5/tests → hydraflow-0.9.0/tests/core}/config/test_params.py +11 -27
  30. {hydraflow-0.7.5/tests → hydraflow-0.9.0/tests/core}/context/chdir.py +5 -2
  31. hydraflow-0.7.5/tests/context/logging.py → hydraflow-0.9.0/tests/core/context/log_run.py +11 -8
  32. hydraflow-0.7.5/tests/context/context.py → hydraflow-0.9.0/tests/core/context/start_run.py +7 -13
  33. {hydraflow-0.7.5/tests → hydraflow-0.9.0/tests/core}/context/test_chdir.py +6 -3
  34. hydraflow-0.9.0/tests/core/context/test_log_run.py +58 -0
  35. hydraflow-0.9.0/tests/core/context/test_start_run.py +34 -0
  36. hydraflow-0.9.0/tests/core/io/__init__.py +0 -0
  37. hydraflow-0.7.5/tests/utils/utils.py → hydraflow-0.9.0/tests/core/io/hydra_dir.py +4 -5
  38. hydraflow-0.7.5/tests/utils/test_utils.py → hydraflow-0.9.0/tests/core/io/test_hydra_dir.py +10 -28
  39. {hydraflow-0.7.5/tests/utils → hydraflow-0.9.0/tests/core/io}/test_run.py +5 -5
  40. hydraflow-0.9.0/tests/core/main/__init__.py +0 -0
  41. hydraflow-0.7.5/tests/main/base.py → hydraflow-0.9.0/tests/core/main/default.py +2 -1
  42. hydraflow-0.9.0/tests/core/main/match_overrides.py +24 -0
  43. hydraflow-0.7.5/tests/main/restart.py → hydraflow-0.9.0/tests/core/main/rerun_finished.py +9 -2
  44. hydraflow-0.7.5/tests/main/test_base.py → hydraflow-0.9.0/tests/core/main/test_default.py +14 -9
  45. {hydraflow-0.7.5/tests → hydraflow-0.9.0/tests/core}/main/test_force_new_run.py +6 -3
  46. hydraflow-0.9.0/tests/core/main/test_match_overrides.py +25 -0
  47. hydraflow-0.7.5/tests/main/test_restart.py → hydraflow-0.9.0/tests/core/main/test_rerun_finished.py +6 -3
  48. hydraflow-0.7.5/tests/main/test_skip.py → hydraflow-0.9.0/tests/core/main/test_skip_finished.py +8 -6
  49. hydraflow-0.9.0/tests/core/param/__init__.py +0 -0
  50. {hydraflow-0.7.5/tests → hydraflow-0.9.0/tests/core}/param/params.py +5 -2
  51. {hydraflow-0.7.5/tests → hydraflow-0.9.0/tests/core}/param/test_param.py +6 -6
  52. hydraflow-0.9.0/tests/core/param/test_params.py +49 -0
  53. {hydraflow-0.7.5/tests → hydraflow-0.9.0/tests/core}/test_mlflow.py +10 -28
  54. hydraflow-0.9.0/tests/entities/__init__.py +0 -0
  55. {hydraflow-0.7.5/tests/run → hydraflow-0.9.0/tests/entities}/filter.py +6 -3
  56. {hydraflow-0.7.5/tests/run → hydraflow-0.9.0/tests/entities}/test_collection.py +31 -74
  57. {hydraflow-0.7.5/tests/run → hydraflow-0.9.0/tests/entities}/test_data.py +3 -3
  58. hydraflow-0.9.0/tests/entities/test_filter.py +44 -0
  59. {hydraflow-0.7.5/tests/run → hydraflow-0.9.0/tests/entities}/test_info.py +3 -3
  60. hydraflow-0.9.0/tests/entities/test_values.py +37 -0
  61. hydraflow-0.7.5/tests/run/run.py → hydraflow-0.9.0/tests/entities/values.py +7 -9
  62. hydraflow-0.9.0/tests/executor/__init__.py +0 -0
  63. hydraflow-0.9.0/tests/executor/conftest.py +30 -0
  64. hydraflow-0.9.0/tests/executor/echo.py +17 -0
  65. hydraflow-0.9.0/tests/executor/test_args.py +19 -0
  66. hydraflow-0.9.0/tests/executor/test_conf.py +34 -0
  67. hydraflow-0.9.0/tests/executor/test_io.py +18 -0
  68. hydraflow-0.9.0/tests/executor/test_job.py +127 -0
  69. hydraflow-0.9.0/tests/executor/test_parser.py +220 -0
  70. hydraflow-0.7.5/hydraflow.yaml +0 -5
  71. hydraflow-0.7.5/src/hydraflow/__init__.py +0 -44
  72. hydraflow-0.7.5/src/hydraflow/cli.py +0 -75
  73. hydraflow-0.7.5/src/hydraflow/main.py +0 -54
  74. hydraflow-0.7.5/src/hydraflow/mlflow.py +0 -280
  75. hydraflow-0.7.5/tests/cli/conftest.py +0 -9
  76. hydraflow-0.7.5/tests/cli/test_run.py +0 -18
  77. hydraflow-0.7.5/tests/cli/test_show.py +0 -52
  78. hydraflow-0.7.5/tests/config/overrides.py +0 -32
  79. hydraflow-0.7.5/tests/config/test_config.py +0 -29
  80. hydraflow-0.7.5/tests/config/test_overrides.py +0 -25
  81. hydraflow-0.7.5/tests/conftest.py +0 -81
  82. hydraflow-0.7.5/tests/context/rerun.py +0 -40
  83. hydraflow-0.7.5/tests/context/test_context.py +0 -23
  84. hydraflow-0.7.5/tests/context/test_logging.py +0 -51
  85. hydraflow-0.7.5/tests/context/test_rerun.py +0 -31
  86. hydraflow-0.7.5/tests/param/test_params.py +0 -35
  87. hydraflow-0.7.5/tests/run/test_filter.py +0 -19
  88. hydraflow-0.7.5/tests/run/test_run.py +0 -54
  89. {hydraflow-0.7.5 → hydraflow-0.9.0}/.devcontainer/devcontainer.json +0 -0
  90. {hydraflow-0.7.5 → hydraflow-0.9.0}/.devcontainer/postCreate.sh +0 -0
  91. {hydraflow-0.7.5 → hydraflow-0.9.0}/.devcontainer/starship.toml +0 -0
  92. {hydraflow-0.7.5 → hydraflow-0.9.0}/.gitattributes +0 -0
  93. {hydraflow-0.7.5 → hydraflow-0.9.0}/.github/workflows/ci.yaml +0 -0
  94. {hydraflow-0.7.5 → hydraflow-0.9.0}/.github/workflows/docs.yaml +0 -0
  95. {hydraflow-0.7.5 → hydraflow-0.9.0}/.gitignore +0 -0
  96. {hydraflow-0.7.5 → hydraflow-0.9.0}/LICENSE +0 -0
  97. {hydraflow-0.7.5 → hydraflow-0.9.0}/docs/index.md +0 -0
  98. {hydraflow-0.7.5 → hydraflow-0.9.0}/mkdocs.yaml +0 -0
  99. {hydraflow-0.7.5/tests → hydraflow-0.9.0/src/hydraflow/core}/__init__.py +0 -0
  100. {hydraflow-0.7.5/tests/cli → hydraflow-0.9.0/src/hydraflow/entities}/__init__.py +0 -0
  101. {hydraflow-0.7.5/tests/config → hydraflow-0.9.0/src/hydraflow/executor}/__init__.py +0 -0
  102. {hydraflow-0.7.5 → hydraflow-0.9.0}/src/hydraflow/py.typed +0 -0
  103. {hydraflow-0.7.5/tests/context → hydraflow-0.9.0/tests}/__init__.py +0 -0
  104. {hydraflow-0.7.5/tests/main → hydraflow-0.9.0/tests/cli}/__init__.py +0 -0
  105. {hydraflow-0.7.5 → hydraflow-0.9.0}/tests/cli/test_version.py +0 -0
  106. {hydraflow-0.7.5/tests/param → hydraflow-0.9.0/tests/core}/__init__.py +0 -0
  107. {hydraflow-0.7.5/tests/run → hydraflow-0.9.0/tests/core/config}/__init__.py +0 -0
  108. {hydraflow-0.7.5/tests/utils → hydraflow-0.9.0/tests/core/context}/__init__.py +0 -0
  109. {hydraflow-0.7.5/tests → hydraflow-0.9.0/tests/core}/main/force_new_run.py +0 -0
  110. hydraflow-0.7.5/tests/main/skip.py → hydraflow-0.9.0/tests/core/main/skip_finished.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: hydraflow
3
- Version: 0.7.5
3
+ Version: 0.9.0
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://daizutabi.github.io/hydraflow/
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -41,6 +41,7 @@ Requires-Dist: mlflow>=2.15
41
41
  Requires-Dist: omegaconf
42
42
  Requires-Dist: rich
43
43
  Requires-Dist: typer
44
+ Requires-Dist: ulid
44
45
  Description-Content-Type: text/markdown
45
46
 
46
47
  # Hydraflow
@@ -93,31 +94,29 @@ pip install hydraflow
93
94
  Here is a simple example to get you started with Hydraflow:
94
95
 
95
96
  ```python
96
- import hydra
97
- import hydraflow
98
- import mlflow
97
+ from __future__ import annotations
98
+
99
99
  from dataclasses import dataclass
100
- from hydra.core.config_store import ConfigStore
101
100
  from pathlib import Path
101
+ from typing import TYPE_CHECKING
102
102
 
103
- @dataclass
104
- class MySQLConfig:
105
- host: str = "localhost"
106
- port: int = 3306
103
+ import hydraflow
107
104
 
108
- cs = ConfigStore.instance()
109
- cs.store(name="config", node=MySQLConfig)
105
+ if TYPE_CHECKING:
106
+ from mlflow.entities import Run
107
+
108
+
109
+ @dataclass
110
+ class Config:
111
+ count: int = 1
112
+ name: str = "a"
110
113
 
111
- @hydra.main(version_base=None, config_name="config")
112
- def my_app(cfg: MySQLConfig) -> None:
113
- # Set experiment by Hydra job name.
114
- hydraflow.set_experiment()
115
114
 
116
- # Automatically log Hydra config as params.
117
- with hydraflow.start_run(cfg):
118
- # Your app code below.
115
+ @hydraflow.main(Config)
116
+ def app(run: Run, cfg: Config):
117
+ """Your app code here."""
119
118
 
120
119
 
121
120
  if __name__ == "__main__":
122
- my_app()
121
+ app()
123
122
  ```
@@ -48,31 +48,29 @@ pip install hydraflow
48
48
  Here is a simple example to get you started with Hydraflow:
49
49
 
50
50
  ```python
51
- import hydra
52
- import hydraflow
53
- import mlflow
51
+ from __future__ import annotations
52
+
54
53
  from dataclasses import dataclass
55
- from hydra.core.config_store import ConfigStore
56
54
  from pathlib import Path
55
+ from typing import TYPE_CHECKING
57
56
 
58
- @dataclass
59
- class MySQLConfig:
60
- host: str = "localhost"
61
- port: int = 3306
57
+ import hydraflow
62
58
 
63
- cs = ConfigStore.instance()
64
- cs.store(name="config", node=MySQLConfig)
59
+ if TYPE_CHECKING:
60
+ from mlflow.entities import Run
61
+
62
+
63
+ @dataclass
64
+ class Config:
65
+ count: int = 1
66
+ name: str = "a"
65
67
 
66
- @hydra.main(version_base=None, config_name="config")
67
- def my_app(cfg: MySQLConfig) -> None:
68
- # Set experiment by Hydra job name.
69
- hydraflow.set_experiment()
70
68
 
71
- # Automatically log Hydra config as params.
72
- with hydraflow.start_run(cfg):
73
- # Your app code below.
69
+ @hydraflow.main(Config)
70
+ def app(run: Run, cfg: Config):
71
+ """Your app code here."""
74
72
 
75
73
 
76
74
  if __name__ == "__main__":
77
- my_app()
75
+ app()
78
76
  ```
@@ -2,7 +2,9 @@ import logging
2
2
  from dataclasses import dataclass
3
3
 
4
4
  import hydra
5
+ import mlflow
5
6
  from hydra.core.config_store import ConfigStore
7
+ from hydra.core.hydra_config import HydraConfig
6
8
 
7
9
  import hydraflow
8
10
 
@@ -19,9 +21,10 @@ cs = ConfigStore.instance()
19
21
  cs.store(name="config", node=Config)
20
22
 
21
23
 
22
- @hydra.main(version_base=None, config_name="config")
24
+ @hydra.main(config_name="config", version_base=None)
23
25
  def app(cfg: Config) -> None:
24
- hydraflow.set_experiment()
26
+ hc = HydraConfig.get()
27
+ mlflow.set_experiment(hc.job.name)
25
28
 
26
29
  with hydraflow.start_run(cfg):
27
30
  log.info(f"{cfg.width=}, {cfg.height=}")
@@ -12,16 +12,6 @@ There are two main steps to using Hydraflow:
12
12
  --8<-- "apps/quickstart.py"
13
13
  ```
14
14
 
15
- ### Set the MLflow experiment
16
-
17
- [`hydraflow.set_experiment`][] sets the MLflow experiment using the Hydra job name.
18
- Optionally, it can also set the tracking URI with `uri` argument.
19
- For example,
20
-
21
- ```python
22
- hydraflow.set_experiment(uri="sqlite:///mlruns.db")
23
- ```
24
-
25
15
  ### Start a new MLflow run
26
16
 
27
17
  [`hydraflow.start_run`][] starts a new MLflow run that logs the Hydra configuration.
@@ -64,10 +54,8 @@ $ python apps/quickstart.py -m width=400,600 height=100,200,300
64
54
  ### Run collection
65
55
 
66
56
  ```pycon exec="1" source="console" session="quickstart"
67
- >>> import mlflow
68
- >>> mlflow.set_experiment("quickstart")
69
57
  >>> import hydraflow
70
- >>> rc = hydraflow.list_runs()
58
+ >>> rc = hydraflow.list_runs("quickstart")
71
59
  >>> print(rc)
72
60
  ```
73
61
 
@@ -107,28 +95,6 @@ $ python apps/quickstart.py -m width=400,600 height=100,200,300
107
95
  >>> print(filtered)
108
96
  ```
109
97
 
110
- ```pycon exec="1" source="console" session="quickstart"
111
- >>> run = rc.find(height=100)
112
- >>> print(run.data.params)
113
- ```
114
-
115
- ```pycon exec="1" source="console" session="quickstart"
116
- >>> run = rc.find_last(height=100)
117
- >>> print(run.data.params)
118
- ```
119
-
120
- ### Map runs
121
-
122
- ```pycon exec="1" source="console" session="quickstart"
123
- >>> params = rc.map(lambda x: x.data.params)
124
- >>> for p in params:
125
- ... print(p)
126
- ```
127
-
128
- ```pycon exec="1" source="console" session="quickstart"
129
- >>> list(rc.map_id(print))
130
- ```
131
-
132
98
  ### Group runs
133
99
 
134
100
  ```pycon exec="1" source="console" session="quickstart"
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.7.5"
7
+ version = "0.9.0"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -19,7 +19,14 @@ classifiers = [
19
19
  "Programming Language :: Python :: 3.13",
20
20
  ]
21
21
  requires-python = ">=3.10"
22
- dependencies = ["hydra-core>=1.3", "mlflow>=2.15", "omegaconf", "rich", "typer"]
22
+ dependencies = [
23
+ "hydra-core>=1.3",
24
+ "mlflow>=2.15",
25
+ "omegaconf",
26
+ "rich",
27
+ "typer",
28
+ "ulid",
29
+ ]
23
30
 
24
31
  [project.urls]
25
32
  Documentation = "https://daizutabi.github.io/hydraflow/"
@@ -44,6 +51,7 @@ addopts = [
44
51
  "--cov=hydraflow",
45
52
  "--cov-report=lcov:lcov.info",
46
53
  "--dist=loadgroup",
54
+ "--doctest-modules",
47
55
  "-n8",
48
56
  ]
49
57
  filterwarnings = [
@@ -67,29 +75,28 @@ ignore = [
67
75
  "ANN003",
68
76
  "ANN401",
69
77
  "B904",
78
+ "D104",
70
79
  "D105",
71
80
  "D107",
72
81
  "D203",
73
82
  "D213",
74
83
  "EM101",
84
+ "FBT001",
85
+ "FBT002",
75
86
  "PGH003",
87
+ "PLR0911",
88
+ "PLR0913",
76
89
  "PLR1704",
90
+ "PLR2004",
91
+ "S603",
92
+ "SIM102",
93
+ "SIM108",
77
94
  "TRY003",
78
95
  ]
79
96
 
80
97
  [tool.ruff.lint.per-file-ignores]
81
- "tests/*" = [
82
- "A001",
83
- "ANN",
84
- "ARG",
85
- "D",
86
- "FBT",
87
- "PLR",
88
- "PT",
89
- "S",
90
- "SIM108",
91
- "SLF",
92
- ]
93
98
  "apps/*.py" = ["D", "G", "INP"]
94
- "src/hydraflow/main.py" = ["ANN201", "D401", "PLR0913"]
95
99
  "src/hydraflow/cli.py" = ["ANN", "D"]
100
+ "src/hydraflow/core/main.py" = ["ANN201", "D401"]
101
+ "src/hydraflow/executor/conf.py" = ["ANN", "D"]
102
+ "tests/*" = ["A001", "ANN", "ARG", "D", "FBT", "PD", "PLR", "PT", "S", "SLF"]
@@ -0,0 +1,29 @@
1
+ """Integrate Hydra and MLflow to manage and track machine learning experiments."""
2
+
3
+ from hydraflow.core.context import chdir_artifact, log_run, start_run
4
+ from hydraflow.core.io import (
5
+ get_artifact_dir,
6
+ get_artifact_path,
7
+ get_hydra_output_dir,
8
+ load_config,
9
+ remove_run,
10
+ )
11
+ from hydraflow.core.main import main
12
+ from hydraflow.core.mlflow import list_run_ids, list_run_paths, list_runs
13
+ from hydraflow.entities.run_collection import RunCollection
14
+
15
+ __all__ = [
16
+ "RunCollection",
17
+ "chdir_artifact",
18
+ "get_artifact_dir",
19
+ "get_artifact_path",
20
+ "get_hydra_output_dir",
21
+ "list_run_ids",
22
+ "list_run_paths",
23
+ "list_runs",
24
+ "load_config",
25
+ "log_run",
26
+ "main",
27
+ "remove_run",
28
+ "start_run",
29
+ ]
@@ -0,0 +1,67 @@
1
+ """Hydraflow CLI."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING, Annotated
6
+
7
+ import typer
8
+ from rich.console import Console
9
+ from typer import Argument, Option
10
+
11
+ from hydraflow.executor.io import load_config
12
+
13
+ if TYPE_CHECKING:
14
+ from hydraflow.executor.job import Job
15
+
16
+ app = typer.Typer(add_completion=False)
17
+ console = Console()
18
+
19
+
20
+ def get_job(name: str) -> Job:
21
+ cfg = load_config()
22
+ job = cfg.jobs[name]
23
+
24
+ if not job.name:
25
+ job.name = name
26
+
27
+ return job
28
+
29
+
30
+ @app.command()
31
+ def run(
32
+ name: Annotated[str, Argument(help="Job name.", show_default=False)],
33
+ ) -> None:
34
+ """Run a job."""
35
+ import mlflow
36
+
37
+ from hydraflow.executor.job import multirun
38
+
39
+ job = get_job(name)
40
+ mlflow.set_experiment(job.name)
41
+ multirun(job)
42
+
43
+
44
+ @app.command()
45
+ def show(
46
+ name: Annotated[str, Argument(help="Job name.", show_default=False)],
47
+ ) -> None:
48
+ """Show a job."""
49
+ from hydraflow.executor.job import show
50
+
51
+ job = get_job(name)
52
+ show(job)
53
+
54
+
55
+ @app.callback(invoke_without_command=True)
56
+ def callback(
57
+ *,
58
+ version: Annotated[
59
+ bool,
60
+ Option("--version", help="Show the version and exit."),
61
+ ] = False,
62
+ ) -> None:
63
+ if version:
64
+ import importlib.metadata
65
+
66
+ typer.echo(f"hydraflow {importlib.metadata.version('hydraflow')}")
67
+ raise typer.Exit
@@ -6,35 +6,19 @@ from typing import TYPE_CHECKING
6
6
 
7
7
  from omegaconf import DictConfig, ListConfig, OmegaConf
8
8
 
9
- from hydraflow.utils import get_overrides
10
-
11
9
  if TYPE_CHECKING:
12
10
  from collections.abc import Iterator
13
11
  from typing import Any
14
12
 
15
13
 
16
- def collect_params(config: object) -> dict[str, Any]:
17
- """Iterate over parameters and collect them into a dictionary.
18
-
19
- Args:
20
- config (object): The configuration object to iterate over.
21
- prefix (str): The prefix to prepend to the parameter keys.
22
-
23
- Returns:
24
- dict[str, Any]: A dictionary of collected parameters.
25
-
26
- """
27
- return dict(iter_params(config))
28
-
29
-
30
- def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
14
+ def iter_params(config: Any, prefix: str = "") -> Iterator[tuple[str, Any]]:
31
15
  """Recursively iterate over the parameters in the given configuration object.
32
16
 
33
17
  This function traverses the configuration object and yields key-value pairs
34
18
  representing the parameters. The keys are prefixed with the provided prefix.
35
19
 
36
20
  Args:
37
- config (object): The configuration object to iterate over. This can be a
21
+ config (Any): The configuration object to iterate over. This can be a
38
22
  dictionary, list, DictConfig, or ListConfig.
39
23
  prefix (str): The prefix to prepend to the parameter keys.
40
24
  Defaults to an empty string.
@@ -50,7 +34,7 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
50
34
  config = _from_dotlist(config)
51
35
 
52
36
  if not isinstance(config, DictConfig | ListConfig):
53
- config = OmegaConf.create(config) # type: ignore
37
+ config = OmegaConf.create(config)
54
38
 
55
39
  yield from _iter_params(config, prefix)
56
40
 
@@ -65,7 +49,7 @@ def _from_dotlist(config: list[str]) -> dict[str, str]:
65
49
  return result
66
50
 
67
51
 
68
- def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
52
+ def _iter_params(config: Any, prefix: str = "") -> Iterator[tuple[str, Any]]:
69
53
  if isinstance(config, DictConfig):
70
54
  for key, value in config.items():
71
55
  if _is_param(value):
@@ -83,12 +67,12 @@ def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
83
67
  yield from _iter_params(value, f"{prefix}{index}.")
84
68
 
85
69
 
86
- def _is_param(value: object) -> bool:
70
+ def _is_param(value: Any) -> bool:
87
71
  """Check if the given value is a parameter."""
88
72
  if isinstance(value, DictConfig):
89
73
  return False
90
74
 
91
- if isinstance(value, ListConfig): # noqa: SIM102
75
+ if isinstance(value, ListConfig):
92
76
  if any(isinstance(v, DictConfig | ListConfig) for v in value):
93
77
  return False
94
78
 
@@ -103,14 +87,14 @@ def _convert(value: Any) -> Any:
103
87
  return value
104
88
 
105
89
 
106
- def select_config(config: object, names: list[str]) -> dict[str, Any]:
90
+ def select_config(config: Any, names: list[str]) -> dict[str, Any]:
107
91
  """Select the given parameters from the configuration object.
108
92
 
109
93
  This function selects the given parameters from the configuration object
110
94
  and returns a new configuration object containing only the selected parameters.
111
95
 
112
96
  Args:
113
- config (object): The configuration object to select parameters from.
97
+ config (Any): The configuration object to select parameters from.
114
98
  names (list[str]): The names of the parameters to select.
115
99
 
116
100
  Returns:
@@ -120,7 +104,7 @@ def select_config(config: object, names: list[str]) -> dict[str, Any]:
120
104
  if not isinstance(config, DictConfig):
121
105
  config = OmegaConf.structured(config)
122
106
 
123
- return {name: _get(config, name) for name in names} # type: ignore
107
+ return {name: _get(config, name) for name in names}
124
108
 
125
109
 
126
110
  def _get(config: DictConfig, name: str) -> Any:
@@ -132,8 +116,7 @@ def _get(config: DictConfig, name: str) -> Any:
132
116
  return _get(config.get(prefix), name)
133
117
 
134
118
 
135
- def select_overrides(config: object) -> dict[str, Any]:
119
+ def select_overrides(config: object, overrides: list[str]) -> dict[str, Any]:
136
120
  """Select the given overrides from the configuration object."""
137
- overrides = get_overrides()
138
121
  names = [override.split("=")[0].strip() for override in overrides]
139
122
  return select_config(config, names)
@@ -12,8 +12,9 @@ import mlflow
12
12
  import mlflow.artifacts
13
13
  from hydra.core.hydra_config import HydraConfig
14
14
 
15
- from hydraflow.mlflow import log_params
16
- from hydraflow.utils import get_artifact_dir
15
+ from hydraflow.core.io import get_artifact_dir
16
+
17
+ from .mlflow import log_params, log_text
17
18
 
18
19
  if TYPE_CHECKING:
19
20
  from collections.abc import Iterator
@@ -55,11 +56,11 @@ def log_run(
55
56
  log_params(config, synchronous=synchronous)
56
57
 
57
58
  hc = HydraConfig.get()
58
- output_dir = Path(hc.runtime.output_dir)
59
+ hydra_dir = Path(hc.runtime.output_dir)
59
60
 
60
61
  # Save '.hydra' config directory.
61
- output_subdir = output_dir / (hc.output_subdir or "")
62
- mlflow.log_artifacts(output_subdir.as_posix(), hc.output_subdir)
62
+ hydra_subdir = hydra_dir / (hc.output_subdir or "")
63
+ mlflow.log_artifacts(hydra_subdir.as_posix(), hc.output_subdir)
63
64
 
64
65
  try:
65
66
  yield
@@ -70,43 +71,14 @@ def log_run(
70
71
  raise
71
72
 
72
73
  finally:
73
- log_text(output_dir)
74
-
75
-
76
- def log_text(directory: Path, pattern: str = "*.log") -> None:
77
- """Log text files in the given directory as artifacts.
78
-
79
- Append the text files to the existing text file in the artifact directory.
80
-
81
- Args:
82
- directory (Path): The directory to find the logs in.
83
- pattern (str): The pattern to match the logs.
84
-
85
- """
86
- artifact_dir = get_artifact_dir()
87
-
88
- for file in directory.glob(pattern):
89
- if not file.is_file():
90
- continue
91
-
92
- file_artifact = artifact_dir / file.name
93
- if file_artifact.exists():
94
- text = file_artifact.read_text()
95
- if not text.endswith("\n"):
96
- text += "\n"
97
- else:
98
- text = ""
99
-
100
- text += file.read_text()
101
- mlflow.log_text(text, file.name)
74
+ log_text(hydra_dir)
102
75
 
103
76
 
104
77
  @contextmanager
105
- def start_run( # noqa: PLR0913
78
+ def start_run(
106
79
  config: object,
107
80
  *,
108
81
  chdir: bool = False,
109
- run: Run | None = None,
110
82
  run_id: str | None = None,
111
83
  experiment_id: str | None = None,
112
84
  run_name: str | None = None,
@@ -126,7 +98,6 @@ def start_run( # noqa: PLR0913
126
98
  config (object): The configuration object to log parameters from.
127
99
  chdir (bool): Whether to change the current working directory to the
128
100
  artifact directory of the current run. Defaults to False.
129
- run (Run | None): The existing run. Defaults to None.
130
101
  run_id (str | None): The existing run ID. Defaults to None.
131
102
  experiment_id (str | None): The experiment ID. Defaults to None.
132
103
  run_name (str | None): The name of the run. Defaults to None.
@@ -142,20 +113,7 @@ def start_run( # noqa: PLR0913
142
113
  Yields:
143
114
  Run: An MLflow Run object representing the started run.
144
115
 
145
- Example:
146
- with start_run(config) as run:
147
- # Perform operations within the MLflow run context
148
- pass
149
-
150
- See Also:
151
- - `mlflow.start_run`: The MLflow function to start a run directly.
152
- - `log_run`: A context manager to log parameters and manage the MLflow
153
- run context.
154
-
155
116
  """
156
- if run:
157
- run_id = run.info.run_id
158
-
159
117
  with (
160
118
  mlflow.start_run(
161
119
  run_id=run_id,
@@ -12,46 +12,42 @@ import mlflow
12
12
  import mlflow.artifacts
13
13
  from hydra.core.hydra_config import HydraConfig
14
14
  from mlflow.entities import Run
15
- from omegaconf import DictConfig, OmegaConf
15
+ from omegaconf import DictConfig, ListConfig, OmegaConf
16
16
 
17
17
  if TYPE_CHECKING:
18
18
  from collections.abc import Iterable
19
19
 
20
20
 
21
- def get_artifact_dir(run: Run | None = None, uri: str | None = None) -> Path:
21
+ def file_uri_to_path(uri: str) -> Path:
22
+ """Convert a file URI to a local path."""
23
+ if not uri.startswith("file:"):
24
+ return Path(uri)
25
+
26
+ path = urllib.parse.urlparse(uri).path
27
+ return Path(urllib.request.url2pathname(path)) # for Windows
28
+
29
+
30
+ def get_artifact_dir(run: Run | None = None) -> Path:
22
31
  """Retrieve the artifact directory for the given run.
23
32
 
24
33
  This function uses MLflow to get the artifact directory for the given run.
25
34
 
26
35
  Args:
27
36
  run (Run | None): The run object. Defaults to None.
28
- uri (str | None): The URI of the artifact. Defaults to None.
29
37
 
30
38
  Returns:
31
39
  The local path to the directory where the artifacts are downloaded.
32
40
 
33
41
  """
34
- if run is not None and uri is not None:
35
- raise ValueError("Cannot provide both run and uri")
36
-
37
- if run is None and uri is None:
42
+ if run is None:
38
43
  uri = mlflow.get_artifact_uri()
39
- elif run:
44
+ else:
40
45
  uri = run.info.artifact_uri
41
46
 
42
47
  if not isinstance(uri, str):
43
48
  raise NotImplementedError
44
49
 
45
- if uri.startswith("file:"):
46
- return file_uri_to_path(uri)
47
-
48
- return Path(uri)
49
-
50
-
51
- def file_uri_to_path(uri: str) -> Path:
52
- """Convert a file URI to a local path."""
53
- path = urllib.parse.urlparse(uri).path
54
- return Path(urllib.request.url2pathname(path)) # for Windows
50
+ return file_uri_to_path(uri)
55
51
 
56
52
 
57
53
  def get_artifact_path(run: Run | None, path: str) -> Path:
@@ -123,12 +119,7 @@ def load_config(run: Run) -> DictConfig:
123
119
  return OmegaConf.load(path) # type: ignore
124
120
 
125
121
 
126
- def get_overrides() -> list[str]:
127
- """Retrieve the overrides for the current run."""
128
- return list(HydraConfig.get().overrides.task) # ListConifg -> list
129
-
130
-
131
- def load_overrides(run: Run) -> list[str]:
122
+ def load_overrides(run: Run) -> ListConfig:
132
123
  """Load the overrides for a given run.
133
124
 
134
125
  This function loads the overrides for the provided Run instance
@@ -137,15 +128,15 @@ def load_overrides(run: Run) -> list[str]:
137
128
  `.hydra/overrides.yaml` is not found in the run's artifact directory.
138
129
 
139
130
  Args:
140
- run (Run): The Run instance for which to load the overrides.
131
+ run (Run): The Run instance for which to load the configuration.
141
132
 
142
133
  Returns:
143
- The loaded overrides as a list of strings. Returns an empty list
144
- if the overrides file is not found.
134
+ The loaded configuration as a DictConfig object. Returns an empty
135
+ DictConfig if the configuration file is not found.
145
136
 
146
137
  """
147
138
  path = get_artifact_dir(run) / ".hydra/overrides.yaml"
148
- return [str(x) for x in OmegaConf.load(path)]
139
+ return OmegaConf.load(path) # type: ignore
149
140
 
150
141
 
151
142
  def remove_run(run: Run | Iterable[Run]) -> None: