hydraflow 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hydraflow/__init__.py CHANGED
@@ -1,30 +1,22 @@
1
- from .context import Info, chdir_artifact, log_run, watch
2
- from .mlflow import set_experiment
1
+ from .context import chdir_artifact, log_run, start_run, watch
2
+ from .mlflow import get_artifact_dir, get_hydra_output_dir, set_experiment
3
3
  from .runs import (
4
- Run,
5
4
  RunCollection,
6
- filter_runs,
7
- get_param_dict,
8
- get_param_names,
9
- get_run,
10
5
  list_runs,
11
6
  load_config,
12
7
  search_runs,
13
8
  )
14
9
 
15
10
  __all__ = [
16
- "Info",
17
- "Run",
18
11
  "RunCollection",
19
12
  "chdir_artifact",
20
- "filter_runs",
21
- "get_param_dict",
22
- "get_param_names",
23
- "get_run",
13
+ "get_artifact_dir",
14
+ "get_hydra_output_dir",
24
15
  "list_runs",
25
16
  "load_config",
26
17
  "log_run",
27
18
  "search_runs",
28
19
  "set_experiment",
20
+ "start_run",
29
21
  "watch",
30
22
  ]
hydraflow/config.py CHANGED
@@ -22,9 +22,9 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
22
22
  representing the parameters. The keys are prefixed with the provided prefix.
23
23
 
24
24
  Args:
25
- config: The configuration object to iterate over. This can be a dictionary,
26
- list, DictConfig, or ListConfig.
27
- prefix: The prefix to prepend to the parameter keys.
25
+ config (object): The configuration object to iterate over. This can be a
26
+ dictionary, list, DictConfig, or ListConfig.
27
+ prefix (str): The prefix to prepend to the parameter keys.
28
28
  Defaults to an empty string.
29
29
 
30
30
  Yields:
hydraflow/context.py CHANGED
@@ -9,7 +9,6 @@ import logging
9
9
  import os
10
10
  import time
11
11
  from contextlib import contextmanager
12
- from dataclasses import dataclass
13
12
  from pathlib import Path
14
13
  from typing import TYPE_CHECKING
15
14
 
@@ -28,18 +27,12 @@ if TYPE_CHECKING:
28
27
  log = logging.getLogger(__name__)
29
28
 
30
29
 
31
- @dataclass
32
- class Info:
33
- output_dir: Path
34
- artifact_dir: Path
35
-
36
-
37
30
  @contextmanager
38
31
  def log_run(
39
32
  config: object,
40
33
  *,
41
34
  synchronous: bool | None = None,
42
- ) -> Iterator[Info]:
35
+ ) -> Iterator[None]:
43
36
  """
44
37
  Log the parameters from the given configuration object and manage the MLflow
45
38
  run context.
@@ -49,16 +42,15 @@ def log_run(
49
42
  are logged and the run is properly closed.
50
43
 
51
44
  Args:
52
- config: The configuration object to log the parameters from.
53
- synchronous: Whether to log the parameters synchronously.
45
+ config (object): The configuration object to log the parameters from.
46
+ synchronous (bool | None): Whether to log the parameters synchronously.
54
47
  Defaults to None.
55
48
 
56
49
  Yields:
57
- Info: An `Info` object containing the output directory and artifact directory
58
- paths.
50
+ None
59
51
 
60
52
  Example:
61
- with log_run(config) as info:
53
+ with log_run(config):
62
54
  # Perform operations within the MLflow run context
63
55
  pass
64
56
  """
@@ -66,7 +58,6 @@ def log_run(
66
58
 
67
59
  hc = HydraConfig.get()
68
60
  output_dir = Path(hc.runtime.output_dir)
69
- info = Info(output_dir, get_artifact_dir())
70
61
 
71
62
  # Save '.hydra' config directory first.
72
63
  output_subdir = output_dir / (hc.output_subdir or "")
@@ -78,7 +69,7 @@ def log_run(
78
69
 
79
70
  try:
80
71
  with watch(log_artifact, output_dir):
81
- yield info
72
+ yield
82
73
 
83
74
  except Exception as e:
84
75
  log.error(f"Error during log_run: {e}")
@@ -89,6 +80,64 @@ def log_run(
89
80
  mlflow.log_artifacts(output_dir.as_posix())
90
81
 
91
82
 
83
+ @contextmanager
84
+ def start_run(
85
+ config: object,
86
+ *,
87
+ run_id: str | None = None,
88
+ experiment_id: str | None = None,
89
+ run_name: str | None = None,
90
+ nested: bool = False,
91
+ parent_run_id: str | None = None,
92
+ tags: dict[str, str] | None = None,
93
+ description: str | None = None,
94
+ log_system_metrics: bool | None = None,
95
+ synchronous: bool | None = None,
96
+ ) -> Iterator[Run]:
97
+ """
98
+ Start an MLflow run and log parameters using the provided configuration object.
99
+
100
+ This context manager starts an MLflow run and logs parameters using the specified
101
+ configuration object. It ensures that the run is properly closed after completion.
102
+
103
+ Args:
104
+ config (object): The configuration object to log parameters from.
105
+ run_id (str | None): The existing run ID. Defaults to None.
106
+ experiment_id (str | None): The experiment ID. Defaults to None.
107
+ run_name (str | None): The name of the run. Defaults to None.
108
+ nested (bool): Whether to allow nested runs. Defaults to False.
109
+ parent_run_id (str | None): The parent run ID. Defaults to None.
110
+ tags (dict[str, str] | None): Tags to associate with the run. Defaults to None.
111
+ description (str | None): A description of the run. Defaults to None.
112
+ log_system_metrics (bool | None): Whether to log system metrics. Defaults to None.
113
+ synchronous (bool | None): Whether to log parameters synchronously. Defaults to None.
114
+
115
+ Yields:
116
+ Run: An MLflow Run object representing the started run.
117
+
118
+ Example:
119
+ with start_run(config) as run:
120
+ # Perform operations within the MLflow run context
121
+ pass
122
+
123
+ See Also:
124
+ `mlflow.start_run`: The MLflow function to start a run directly.
125
+ `log_run`: A context manager to log parameters and manage the MLflow run context.
126
+ """
127
+ with mlflow.start_run(
128
+ run_id=run_id,
129
+ experiment_id=experiment_id,
130
+ run_name=run_name,
131
+ nested=nested,
132
+ parent_run_id=parent_run_id,
133
+ tags=tags,
134
+ description=description,
135
+ log_system_metrics=log_system_metrics,
136
+ ) as run:
137
+ with log_run(config, synchronous=synchronous):
138
+ yield run
139
+
140
+
92
141
  @contextmanager
93
142
  def watch(
94
143
  func: Callable[[Path], None],
@@ -105,12 +154,12 @@ def watch(
105
154
  period or until the context is exited.
106
155
 
107
156
  Args:
108
- func: The function to call when a change is
157
+ func (Callable[[Path], None]): The function to call when a change is
109
158
  detected. It should accept a single argument of type `Path`,
110
159
  which is the path of the modified file.
111
- dir: The directory to watch. If not specified,
160
+ dir (Path | str): The directory to watch. If not specified,
112
161
  the current MLflow artifact URI is used. Defaults to "".
113
- timeout: The timeout period in seconds for the watcher
162
+ timeout (int): The timeout period in seconds for the watcher
114
163
  to run after the context is exited. Defaults to 60.
115
164
 
116
165
  Yields:
@@ -122,6 +171,8 @@ def watch(
122
171
  pass
123
172
  """
124
173
  dir = dir or get_artifact_dir()
174
+ if isinstance(dir, Path):
175
+ dir = dir.as_posix()
125
176
 
126
177
  handler = Handler(func)
127
178
  observer = Observer()
@@ -152,7 +203,7 @@ class Handler(FileSystemEventHandler):
152
203
  self.func = func
153
204
 
154
205
  def on_modified(self, event: FileModifiedEvent) -> None:
155
- file = Path(event.src_path)
206
+ file = Path(str(event.src_path))
156
207
  if file.is_file():
157
208
  self.func(file)
158
209
 
@@ -171,8 +222,8 @@ def chdir_artifact(
171
222
  to the original directory after the context is exited.
172
223
 
173
224
  Args:
174
- run: The run to get the artifact directory from.
175
- artifact_path: The artifact path.
225
+ run (Run): The run to get the artifact directory from.
226
+ artifact_path (str | None): The artifact path.
176
227
  """
177
228
  curdir = Path.cwd()
178
229
  path = mlflow.artifacts.download_artifacts(
hydraflow/mlflow.py CHANGED
@@ -6,14 +6,24 @@ configuration objects and set up experiments using MLflow.
6
6
  from __future__ import annotations
7
7
 
8
8
  from pathlib import Path
9
+ from typing import TYPE_CHECKING
9
10
 
10
11
  import mlflow
11
12
  from hydra.core.hydra_config import HydraConfig
13
+ from mlflow.tracking import artifact_utils
14
+ from omegaconf import OmegaConf
12
15
 
13
16
  from hydraflow.config import iter_params
14
17
 
18
+ if TYPE_CHECKING:
19
+ from mlflow.entities.experiment import Experiment
15
20
 
16
- def set_experiment(prefix: str = "", suffix: str = "", uri: str | None = None) -> None:
21
+
22
+ def set_experiment(
23
+ prefix: str = "",
24
+ suffix: str = "",
25
+ uri: str | Path | None = None,
26
+ ) -> Experiment:
17
27
  """
18
28
  Set the experiment name and tracking URI optionally.
19
29
 
@@ -22,16 +32,20 @@ def set_experiment(prefix: str = "", suffix: str = "", uri: str | None = None) -
22
32
  also set the tracking URI.
23
33
 
24
34
  Args:
25
- prefix: The prefix to prepend to the experiment name.
26
- suffix: The suffix to append to the experiment name.
27
- uri: The tracking URI to use.
35
+ prefix (str): The prefix to prepend to the experiment name.
36
+ suffix (str): The suffix to append to the experiment name.
37
+ uri (str | Path | None): The tracking URI to use. Defaults to None.
38
+
39
+ Returns:
40
+ Experiment: An instance of `mlflow.entities.Experiment` representing
41
+ the new active experiment.
28
42
  """
29
- if uri:
43
+ if uri is not None:
30
44
  mlflow.set_tracking_uri(uri)
31
45
 
32
46
  hc = HydraConfig.get()
33
47
  name = f"{prefix}{hc.job.name}{suffix}"
34
- mlflow.set_experiment(name)
48
+ return mlflow.set_experiment(name)
35
49
 
36
50
 
37
51
  def log_params(config: object, *, synchronous: bool | None = None) -> None:
@@ -43,15 +57,19 @@ def log_params(config: object, *, synchronous: bool | None = None) -> None:
43
57
  `mlflow.log_param` method.
44
58
 
45
59
  Args:
46
- config: The configuration object to log the parameters from.
47
- synchronous: Whether to log the parameters synchronously.
60
+ config (object): The configuration object to log the parameters from.
61
+ synchronous (bool | None): Whether to log the parameters synchronously.
48
62
  Defaults to None.
49
63
  """
50
64
  for key, value in iter_params(config):
51
65
  mlflow.log_param(key, value, synchronous=synchronous)
52
66
 
53
67
 
54
- def get_artifact_dir(artifact_path: str | None = None) -> Path:
68
+ def get_artifact_dir(
69
+ artifact_path: str | None = None,
70
+ *,
71
+ run_id: str | None = None,
72
+ ) -> Path:
55
73
  """
56
74
  Get the artifact directory for the given artifact path.
57
75
 
@@ -60,13 +78,47 @@ def get_artifact_dir(artifact_path: str | None = None) -> Path:
60
78
  the path to that directory.
61
79
 
62
80
  Args:
63
- artifact_path: The artifact path for which to get the directory.
64
- Defaults to None.
81
+ artifact_path (str | None): The artifact path for which to get the
82
+ directory. Defaults to None.
83
+ run_id (str | None): The run ID for which to get the artifact directory.
65
84
 
66
85
  Returns:
67
86
  The local path to the directory where the artifacts are downloaded.
68
87
  """
69
- uri = mlflow.get_artifact_uri(artifact_path)
88
+ if run_id is None:
89
+ uri = mlflow.get_artifact_uri(artifact_path)
90
+ else:
91
+ uri = artifact_utils.get_artifact_uri(run_id, artifact_path)
92
+
70
93
  dir = mlflow.artifacts.download_artifacts(artifact_uri=uri)
71
94
 
72
95
  return Path(dir)
96
+
97
+
98
+ def get_hydra_output_dir(*, run_id: str | None = None) -> Path:
99
+ if run_id is None:
100
+ hc = HydraConfig.get()
101
+ return Path(hc.runtime.output_dir)
102
+
103
+ path = get_artifact_dir(run_id=run_id) / ".hydra/hydra.yaml"
104
+
105
+ if path.exists():
106
+ hc = OmegaConf.load(path)
107
+ return Path(hc.hydra.runtime.output_dir)
108
+
109
+ raise FileNotFoundError
110
+
111
+
112
+ # def log_hydra_output_dir(run: Run_ | Series | str) -> None:
113
+ # """
114
+ # Log the Hydra output directory.
115
+
116
+ # Args:
117
+ # run: The run object.
118
+
119
+ # Returns:
120
+ # None
121
+ # """
122
+ # output_dir = get_hydra_output_dir(run)
123
+ # run_id = run if isinstance(run, str) else run.info.run_id
124
+ # mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
hydraflow/runs.py CHANGED
@@ -45,28 +45,38 @@ def search_runs(
45
45
  The returned runs are sorted by their start time in ascending order.
46
46
 
47
47
  Args:
48
- experiment_ids: List of experiment IDs. Search can work with experiment
49
- IDs or experiment names, but not both in the same call. Values
50
- other than ``None`` or ``[]`` will result in error if
48
+ experiment_ids (list[str] | None): List of experiment IDs. Search can
49
+ work with experiment IDs or experiment names, but not both in the
50
+ same call. Values other than ``None`` or ``[]`` will result in
51
+ error if ``experiment_names`` is also not ``None`` or ``[]``.
52
+ ``None`` will default to the active experiment if ``experiment_names``
53
+ is ``None`` or ``[]``.
54
+ experiment_ids (list[str] | None): List of experiment IDs. Search can
55
+ work with experiment IDs or experiment names, but not both in the
56
+ same call. Values other than ``None`` or ``[]`` will result in
57
+ error if ``experiment_names`` is also not ``None`` or ``[]``.
51
58
  ``experiment_names`` is also not ``None`` or ``[]``. ``None`` will
52
59
  default to the active experiment if ``experiment_names`` is ``None``
53
60
  or ``[]``.
54
- filter_string: Filter query string, defaults to searching all runs.
55
- run_view_type: one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``, or
56
- ``ALL`` runs defined in :py:class:`mlflow.entities.ViewType`.
57
- max_results: The maximum number of runs to put in the dataframe. Default
58
- is 100,000 to avoid causing out-of-memory issues on the user's
61
+ filter_string (str): Filter query string, defaults to searching all
62
+ runs.
63
+ run_view_type (int): one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``,
64
+ or ``ALL`` runs defined in :py:class:`mlflow.entities.ViewType`.
65
+ max_results (int): The maximum number of runs to put in the dataframe.
66
+ Default is 100,000 to avoid causing out-of-memory issues on the user's
59
67
  machine.
60
- order_by: List of columns to order by (e.g., "metrics.rmse"). The
61
- ``order_by`` column can contain an optional ``DESC`` or ``ASC``
62
- value. The default is ``ASC``. The default ordering is to sort by
68
+ order_by (list[str] | None): List of columns to order by (e.g.,
69
+ "metrics.rmse"). The ``order_by`` column can contain an optional
70
+ ``DESC`` or ``ASC`` value. The default is ``ASC``. The default
71
+ ordering is to sort by ``start_time DESC``, then ``run_id``.
63
72
  ``start_time DESC``, then ``run_id``.
64
- search_all_experiments: Boolean specifying whether all experiments
65
- should be searched. Only honored if ``experiment_ids`` is ``[]`` or
66
- ``None``.
67
- experiment_names: List of experiment names. Search can work with
68
- experiment IDs or experiment names, but not both in the same call.
69
- Values other than ``None`` or ``[]`` will result in error if
73
+ search_all_experiments (bool): Boolean specifying whether all
74
+ experiments should be searched. Only honored if ``experiment_ids``
75
+ is ``[]`` or ``None``.
76
+ experiment_names (list[str] | None): List of experiment names. Search
77
+ can work with experiment IDs or experiment names, but not both in
78
+ the same call. Values other than ``None`` or ``[]`` will result in
79
+ error if ``experiment_ids`` is also not ``None`` or ``[]``.
70
80
  ``experiment_ids`` is also not ``None`` or ``[]``. ``None`` will
71
81
  default to the active experiment if ``experiment_ids`` is ``None``
72
82
  or ``[]``.
@@ -102,10 +112,10 @@ def list_runs(experiment_names: list[str] | None = None) -> RunCollection:
102
112
  The returned runs are sorted by their start time in ascending order.
103
113
 
104
114
  Args:
105
- experiment_names: List of experiment names to search for runs.
106
- If None or an empty list is provided, the function will search
107
- the currently active experiment or all experiments except the
108
- "Default" experiment.
115
+ experiment_names (list[str] | None): List of experiment names to search
116
+ for runs. If None or an empty list is provided, the function will
117
+ search the currently active experiment or all experiments except
118
+ the "Default" experiment.
109
119
 
110
120
  Returns:
111
121
  A `RunCollection` object containing the runs for the specified experiments.
@@ -138,6 +148,22 @@ class RunCollection:
138
148
  def __len__(self) -> int:
139
149
  return len(self._runs)
140
150
 
151
+ def __iter__(self) -> Iterator[Run]:
152
+ return iter(self._runs)
153
+
154
+ def __getitem__(self, index: int) -> Run:
155
+ return self._runs[index]
156
+
157
+ def __contains__(self, run: Run) -> bool:
158
+ return run in self._runs
159
+
160
+ def sort(
161
+ self,
162
+ key: Callable[[Run], Any] | None = None,
163
+ reverse: bool = False,
164
+ ) -> None:
165
+ self._runs.sort(key=key or (lambda x: x.info.start_time), reverse=reverse)
166
+
141
167
  def first(self) -> Run:
142
168
  """
143
169
  Get the first run in the collection.
@@ -206,9 +232,9 @@ class RunCollection:
206
232
  and exclusive of the upper bound).
207
233
 
208
234
  Args:
209
- config: The configuration object to filter the runs. This can be
210
- any object that provides key-value pairs through the
211
- `iter_params` function.
235
+ config (object | None): The configuration object to filter the runs.
236
+ This can be any object that provides key-value pairs through
237
+ the `iter_params` function.
212
238
  **kwargs: Additional key-value pairs to filter the runs.
213
239
 
214
240
  Returns:
@@ -226,7 +252,7 @@ class RunCollection:
226
252
  is raised.
227
253
 
228
254
  Args:
229
- config: The configuration object to identify the run.
255
+ config (object | None): The configuration object to identify the run.
230
256
  **kwargs: Additional key-value pairs to filter the runs.
231
257
 
232
258
  Returns:
@@ -251,7 +277,7 @@ class RunCollection:
251
277
  returned.
252
278
 
253
279
  Args:
254
- config: The configuration object to identify the run.
280
+ config (object | None): The configuration object to identify the run.
255
281
  **kwargs: Additional key-value pairs to filter the runs.
256
282
 
257
283
  Returns:
@@ -274,7 +300,7 @@ class RunCollection:
274
300
  is raised.
275
301
 
276
302
  Args:
277
- config: The configuration object to identify the run.
303
+ config (object | None): The configuration object to identify the run.
278
304
  **kwargs: Additional key-value pairs to filter the runs.
279
305
 
280
306
  Returns:
@@ -299,7 +325,7 @@ class RunCollection:
299
325
  returned.
300
326
 
301
327
  Args:
302
- config: The configuration object to identify the run.
328
+ config (object | None): The configuration object to identify the run.
303
329
  **kwargs: Additional key-value pairs to filter the runs.
304
330
 
305
331
  Returns:
@@ -322,7 +348,7 @@ class RunCollection:
322
348
  one run matches the criteria, a `ValueError` is raised.
323
349
 
324
350
  Args:
325
- config: The configuration object to identify the run.
351
+ config (object | None): The configuration object to identify the run.
326
352
  **kwargs: Additional key-value pairs to filter the runs.
327
353
 
328
354
  Returns:
@@ -348,7 +374,7 @@ class RunCollection:
348
374
  If more than one run matches the criteria, a `ValueError` is raised.
349
375
 
350
376
  Args:
351
- config: The configuration object to identify the run.
377
+ config (object | None): The configuration object to identify the run.
352
378
  **kwargs: Additional key-value pairs to filter the runs.
353
379
 
354
380
  Returns:
@@ -398,7 +424,8 @@ class RunCollection:
398
424
  results.
399
425
 
400
426
  Args:
401
- func: A function that takes a run and returns a result.
427
+ func (Callable[[Run], T]): A function that takes a run and returns a
428
+ result.
402
429
 
403
430
  Yields:
404
431
  Results obtained by applying the function to each run in the
@@ -412,7 +439,8 @@ class RunCollection:
412
439
  of results.
413
440
 
414
441
  Args:
415
- func: A function that takes a run id and returns a result.
442
+ func (Callable[[str], T]): A function that takes a run id and returns a
443
+ result.
416
444
 
417
445
  Yields:
418
446
  Results obtained by applying the function to each run id in the
@@ -426,8 +454,8 @@ class RunCollection:
426
454
  an iterator of results.
427
455
 
428
456
  Args:
429
- func: A function that takes a run configuration and returns a
430
- result.
457
+ func (Callable[[DictConfig], T]): A function that takes a run
458
+ configuration and returns a result.
431
459
 
432
460
  Yields:
433
461
  Results obtained by applying the function to each run configuration
@@ -445,8 +473,8 @@ class RunCollection:
445
473
  have an artifact URI, None is passed to the function.
446
474
 
447
475
  Args:
448
- func: A function that takes an artifact URI (string or None) and
449
- returns a result.
476
+ func (Callable[[str | None], T]): A function that takes an
477
+ artifact URI (string or None) and returns a result.
450
478
 
451
479
  Yields:
452
480
  Results obtained by applying the function to each artifact URI in the
@@ -464,8 +492,8 @@ class RunCollection:
464
492
  path.
465
493
 
466
494
  Args:
467
- func: A function that takes an artifact directory path (string) and
468
- returns a result.
495
+ func (Callable[[str], T]): A function that takes an artifact directory
496
+ path (string) and returns a result.
469
497
 
470
498
  Yields:
471
499
  Results obtained by applying the function to each artifact directory
@@ -483,9 +511,9 @@ def _param_matches(run: Run, key: str, value: Any) -> bool:
483
511
  and tuples.
484
512
 
485
513
  Args:
486
- run: The run object to check.
487
- key: The parameter key to check.
488
- value: The parameter value to check.
514
+ run (Run): The run object to check.
515
+ key (str): The parameter key to check.
516
+ value (Any): The parameter value to check.
489
517
 
490
518
  Returns:
491
519
  True if the run's parameter matches the specified key-value pair,
@@ -526,10 +554,10 @@ def filter_runs(runs: list[Run], config: object | None = None, **kwargs) -> list
526
554
  exclusive of the upper bound).
527
555
 
528
556
  Args:
529
- runs: The list of runs to filter.
530
- config: The configuration object to filter the runs. This can be any
531
- object that provides key-value pairs through the `iter_params`
532
- function.
557
+ runs (list[Run]): The list of runs to filter.
558
+ config (object | None): The configuration object to filter the runs.
559
+ This can be any object that provides key-value pairs through the
560
+ `iter_params` function.
533
561
  **kwargs: Additional key-value pairs to filter the runs.
534
562
 
535
563
  Returns:
@@ -554,8 +582,8 @@ def find_run(runs: list[Run], config: object | None = None, **kwargs) -> Run:
554
582
  raised.
555
583
 
556
584
  Args:
557
- runs: The runs to filter.
558
- config: The configuration object to identify the run.
585
+ runs (list[Run]): The runs to filter.
586
+ config (object | None): The configuration object to identify the run.
559
587
  **kwargs: Additional key-value pairs to filter the runs.
560
588
 
561
589
  Returns:
@@ -584,8 +612,8 @@ def try_find_run(runs: list[Run], config: object | None = None, **kwargs) -> Run
584
612
  the provided parameters. If no run matches the criteria, None is returned.
585
613
 
586
614
  Args:
587
- runs: The runs to filter.
588
- config: The configuration object to identify the run.
615
+ runs (list[Run]): The runs to filter.
616
+ config (object | None): The configuration object to identify the run.
589
617
  **kwargs: Additional key-value pairs to filter the runs.
590
618
 
591
619
  Returns:
@@ -610,8 +638,8 @@ def find_last_run(runs: list[Run], config: object | None = None, **kwargs) -> Ru
610
638
  is raised.
611
639
 
612
640
  Args:
613
- runs: The runs to filter.
614
- config: The configuration object to identify the run.
641
+ runs (list[Run]): The runs to filter.
642
+ config (object | None): The configuration object to identify the run.
615
643
  **kwargs: Additional key-value pairs to filter the runs.
616
644
 
617
645
  Returns:
@@ -641,8 +669,8 @@ def try_find_last_run(runs: list[Run], config: object | None = None, **kwargs) -
641
669
  the provided parameters. If no run matches the criteria, None is returned.
642
670
 
643
671
  Args:
644
- runs: The runs to filter.
645
- config: The configuration object to identify the run.
672
+ runs (list[Run]): The runs to filter.
673
+ config (object | None): The configuration object to identify the run.
646
674
  **kwargs: Additional key-value pairs to filter the runs.
647
675
 
648
676
  Returns:
@@ -667,8 +695,8 @@ def get_run(runs: list[Run], config: object | None = None, **kwargs) -> Run:
667
695
  than one run matches the criteria, a `ValueError` is raised.
668
696
 
669
697
  Args:
670
- runs: The runs to filter.
671
- config: The configuration object to identify the run.
698
+ runs (list[Run]): The runs to filter.
699
+ config (object | None): The configuration object to identify the run.
672
700
  **kwargs: Additional key-value pairs to filter the runs.
673
701
 
674
702
  Returns:
@@ -707,8 +735,8 @@ def try_get_run(runs: list[Run], config: object | None = None, **kwargs) -> Run
707
735
  If more than one run matches the criteria, a `ValueError` is raised.
708
736
 
709
737
  Args:
710
- runs: The runs to filter.
711
- config: The configuration object to identify the run.
738
+ runs (list[Run]): The runs to filter.
739
+ config (object | None): The configuration object to identify the run.
712
740
  **kwargs: Additional key-value pairs to filter the runs.
713
741
 
714
742
  Returns:
@@ -746,7 +774,7 @@ def get_param_names(runs: list[Run]) -> list[str]:
746
774
  set to ensure uniqueness.
747
775
 
748
776
  Args:
749
- runs: The list of runs from which to extract parameter names.
777
+ runs (list[Run]): The list of runs from which to extract parameter names.
750
778
 
751
779
  Returns:
752
780
  A list of unique parameter names.
@@ -770,7 +798,8 @@ def get_param_dict(runs: list[Run]) -> dict[str, list[str]]:
770
798
  and the values are lists of parameter values.
771
799
 
772
800
  Args:
773
- runs: The list of runs from which to extract parameter names and values.
801
+ runs (list[Run]): The list of runs from which to extract parameter names
802
+ and values.
774
803
 
775
804
  Returns:
776
805
  A dictionary where the keys are parameter names and the values are lists
@@ -795,7 +824,7 @@ def load_config(run: Run) -> DictConfig:
795
824
  `.hydra/config.yaml` is not found in the run's artifact directory.
796
825
 
797
826
  Args:
798
- run: The Run instance for which to load the configuration.
827
+ run (Run): The Run instance for which to load the configuration.
799
828
 
800
829
  Returns:
801
830
  The loaded configuration as a DictConfig object. Returns an empty
@@ -813,37 +842,3 @@ def _load_config(run_id: str) -> DictConfig:
813
842
  return DictConfig({})
814
843
 
815
844
  return OmegaConf.load(path) # type: ignore
816
-
817
-
818
- # def get_hydra_output_dir(run: Run_ | Series | str) -> Path:
819
- # """
820
- # Get the Hydra output directory.
821
-
822
- # Args:
823
- # run: The run object.
824
-
825
- # Returns:
826
- # Path: The Hydra output directory.
827
- # """
828
- # path = get_artifact_dir(run) / ".hydra/hydra.yaml"
829
-
830
- # if path.exists():
831
- # hc = OmegaConf.load(path)
832
- # return Path(hc.hydra.runtime.output_dir)
833
-
834
- # raise FileNotFoundError
835
-
836
-
837
- # def log_hydra_output_dir(run: Run_ | Series | str) -> None:
838
- # """
839
- # Log the Hydra output directory.
840
-
841
- # Args:
842
- # run: The run object.
843
-
844
- # Returns:
845
- # None
846
- # """
847
- # output_dir = get_hydra_output_dir(run)
848
- # run_id = run if isinstance(run, str) else run.info.run_id
849
- # mlflow.log_artifacts(output_dir.as_posix(), run_id=run_id)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.2.3
3
+ Version: 0.2.4
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -99,13 +99,10 @@ def my_app(cfg: MySQLConfig) -> None:
99
99
  # Set experiment by Hydra job name.
100
100
  hydraflow.set_experiment()
101
101
 
102
- # Automatically log params using Hydra config.
103
- with mlflow.start_run(), hydraflow.log_run(cfg) as info:
102
+ # Automatically log Hydra config as params.
103
+ with hydraflow.start_run():
104
104
  # Your app code below.
105
105
 
106
- # `info.output_dir` is the Hydra output directory.
107
- # `info.artifact_dir` is the MLflow artifact directory.
108
-
109
106
  with hydraflow.watch(callback):
110
107
  # Watch files in the MLflow artifact directory.
111
108
  # You can update metrics or log other artifacts
@@ -0,0 +1,10 @@
1
+ hydraflow/__init__.py,sha256=l5BrZAfpJHFkQnDRuETZVjDTntMmzOI3CUwnsm2fGzk,460
2
+ hydraflow/asyncio.py,sha256=yh851L315QHzRBwq6r-uwO2oZKgz1JawHp-fswfxT1E,6175
3
+ hydraflow/config.py,sha256=6TCKNQZ3sSrIEvl245T2udwFuknejyN1dMcIVmOHdrQ,2102
4
+ hydraflow/context.py,sha256=8Qn99yCSkCarDDthQ6hjgW80CBBIg0H7fnLvtw4ZXo8,7248
5
+ hydraflow/mlflow.py,sha256=gGr0fvFEllduA-ByHMeEamM39zVY_30tjtEbkSZ4lHA,3659
6
+ hydraflow/runs.py,sha256=0t2xhjV9DMA1CNDzBYrsHiZrDZ6cNsaSTxi0ikf6k8c,29907
7
+ hydraflow-0.2.4.dist-info/METADATA,sha256=Rw8m1Ir6Lio6jja44oPHnSMdlLbK2KtZ46UQRD38Lq8,4148
8
+ hydraflow-0.2.4.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
+ hydraflow-0.2.4.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
10
+ hydraflow-0.2.4.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- hydraflow/__init__.py,sha256=9v7p2ezUd_LMoRJQS0ay8c7fpaqPZ6Ofq7YPT0rSO5I,528
2
- hydraflow/asyncio.py,sha256=yh851L315QHzRBwq6r-uwO2oZKgz1JawHp-fswfxT1E,6175
3
- hydraflow/config.py,sha256=FNTuCppjCMrZKVByJMrWKbgj3HeMWWwAmQNoyFe029Y,2087
4
- hydraflow/context.py,sha256=MqkEhKEZL_N3eb3v5u9D4EqKkiSmiPyXXafhPkALRlg,5129
5
- hydraflow/mlflow.py,sha256=_Los9E38eG8sTiN8bGwZmvjCrS0S-wSGiA4fyhQM3Zw,2251
6
- hydraflow/runs.py,sha256=0BXSBbNkELP3CzaCGBkejOkpyk5uQUxrdknJPRwR400,29022
7
- hydraflow-0.2.3.dist-info/METADATA,sha256=h5Pxy6EnxTlyyGL8NRr14ZHtLhA9ldmM9GP5sES6KWU,4304
8
- hydraflow-0.2.3.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
- hydraflow-0.2.3.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
10
- hydraflow-0.2.3.dist-info/RECORD,,