hydraflow 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hydraflow/__init__.py CHANGED
@@ -4,7 +4,13 @@ from .context import chdir_artifact, chdir_hydra, log_run, start_run, watch
4
4
  from .mlflow import list_runs, search_runs, set_experiment
5
5
  from .progress import multi_tasks_progress, parallel_progress
6
6
  from .run_collection import RunCollection
7
- from .utils import get_artifact_dir, get_hydra_output_dir, load_config
7
+ from .utils import (
8
+ get_artifact_dir,
9
+ get_hydra_output_dir,
10
+ get_overrides,
11
+ load_config,
12
+ load_overrides,
13
+ )
8
14
 
9
15
  __all__ = [
10
16
  "RunCollection",
@@ -12,8 +18,10 @@ __all__ = [
12
18
  "chdir_hydra",
13
19
  "get_artifact_dir",
14
20
  "get_hydra_output_dir",
21
+ "get_overrides",
15
22
  "list_runs",
16
23
  "load_config",
24
+ "load_overrides",
17
25
  "log_run",
18
26
  "multi_tasks_progress",
19
27
  "parallel_progress",
hydraflow/param.py CHANGED
@@ -72,4 +72,4 @@ def _match_tuple(param: str, value: tuple) -> bool | None:
72
72
  if type(value[0]) is not type(value[1]):
73
73
  return None
74
74
 
75
- return value[0] <= type(value[0])(param) < value[1] # type: ignore
75
+ return value[0] <= type(value[0])(param) <= value[1] # type: ignore
@@ -24,12 +24,12 @@ from itertools import chain
24
24
  from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
25
25
 
26
26
  from mlflow.entities import RunStatus
27
- from polars.dataframe import DataFrame
28
27
 
29
28
  import hydraflow.param
30
- from hydraflow.config import collect_params, iter_params
29
+ from hydraflow.config import iter_params
31
30
  from hydraflow.run_data import RunCollectionData
32
31
  from hydraflow.run_info import RunCollectionInfo
32
+ from hydraflow.utils import load_config
33
33
 
34
34
  if TYPE_CHECKING:
35
35
  from collections.abc import Callable, Iterator
@@ -239,8 +239,8 @@ class RunCollection:
239
239
  The filtering supports:
240
240
  - Exact matches for single values.
241
241
  - Membership checks for lists of values.
242
- - Range checks for tuples of two values (inclusive of the lower bound
243
- and exclusive of the upper bound).
242
+ - Range checks for tuples of two values (inclusive of both the lower
243
+ and upper bound).
244
244
 
245
245
  Args:
246
246
  config (object | None): The configuration object to filter the runs.
@@ -476,7 +476,7 @@ class RunCollection:
476
476
  """
477
477
  return (func(run, *args, **kwargs) for run in self)
478
478
 
479
- def map_run_id(
479
+ def map_id(
480
480
  self,
481
481
  func: Callable[Concatenate[str, P], T],
482
482
  *args: P.args,
@@ -516,7 +516,7 @@ class RunCollection:
516
516
  in the collection.
517
517
 
518
518
  """
519
- return (func(config, *args, **kwargs) for config in self.data.config)
519
+ return (func(load_config(run), *args, **kwargs) for run in self)
520
520
 
521
521
  def map_uri(
522
522
  self,
@@ -569,8 +569,8 @@ class RunCollection:
569
569
 
570
570
  def group_by(
571
571
  self,
572
- *names: str | list[str],
573
- ) -> dict[tuple[str | None, ...], RunCollection]:
572
+ names: str | list[str],
573
+ ) -> dict[str | None | tuple[str | None, ...], RunCollection]:
574
574
  """Group runs by specified parameter names.
575
575
 
576
576
  Group the runs in the collection based on the values of the
@@ -578,33 +578,27 @@ class RunCollection:
578
578
  form a key in the returned dictionary.
579
579
 
580
580
  Args:
581
- *names (str | list[str]): The names of the parameters to group by.
581
+ names (str | list[str]): The names of the parameters to group by.
582
582
  This can be a single parameter name or multiple names provided
583
583
  as separate arguments or as a list.
584
584
 
585
585
  Returns:
586
- dict[tuple[str | None, ...], RunCollection]: A dictionary where the keys
587
- are tuples of parameter values and the values are RunCollection objects
588
- containing the runs that match those parameter values.
586
+ dict[str | None | tuple[str | None, ...], RunCollection]: A
587
+ dictionary where the keys are tuples of parameter values and the
588
+ values are `RunCollection` objects containing the runs that match
589
+ those parameter values.
589
590
 
590
591
  """
591
- grouped_runs: dict[tuple[str | None, ...], list[Run]] = {}
592
+ grouped_runs: dict[str | None | tuple[str | None, ...], list[Run]] = {}
593
+ is_list = isinstance(names, list)
592
594
  for run in self._runs:
593
- key = get_params(run, *names)
595
+ key = get_params(run, names)
596
+ if not is_list:
597
+ key = key[0]
594
598
  grouped_runs.setdefault(key, []).append(run)
595
599
 
596
600
  return {key: RunCollection(runs) for key, runs in grouped_runs.items()}
597
601
 
598
- @property
599
- def config(self) -> DataFrame:
600
- """Get the runs' configurations as a polars DataFrame.
601
-
602
- Returns:
603
- A polars DataFrame containing the runs' configurations.
604
-
605
- """
606
- return DataFrame(self.map_config(collect_params))
607
-
608
602
 
609
603
  def _param_matches(run: Run, key: str, value: Any) -> bool:
610
604
  params = run.data.params
@@ -637,8 +631,8 @@ def filter_runs(
637
631
  The filtering supports:
638
632
  - Exact matches for single values.
639
633
  - Membership checks for lists of values.
640
- - Range checks for tuples of two values (inclusive of the lower bound and
641
- exclusive of the upper bound).
634
+ - Range checks for tuples of two values (inclusive of both the lower and
635
+ upper bound).
642
636
 
643
637
  Args:
644
638
  runs (list[Run]): The list of runs to filter.
hydraflow/run_data.py CHANGED
@@ -4,10 +4,13 @@ from __future__ import annotations
4
4
 
5
5
  from typing import TYPE_CHECKING
6
6
 
7
- from hydraflow.utils import load_config
7
+ from polars.dataframe import DataFrame
8
+
9
+ from hydraflow.config import collect_params
8
10
 
9
11
  if TYPE_CHECKING:
10
- from omegaconf import DictConfig
12
+ from collections.abc import Iterable
13
+ from typing import Any
11
14
 
12
15
  from hydraflow.run_collection import RunCollection
13
16
 
@@ -19,16 +22,36 @@ class RunCollectionData:
19
22
  self._runs = runs
20
23
 
21
24
  @property
22
- def params(self) -> list[dict[str, str]]:
25
+ def params(self) -> dict[str, list[str]]:
23
26
  """Get the parameters for each run in the collection."""
24
- return [run.data.params for run in self._runs]
27
+ return _to_dict(run.data.params for run in self._runs)
25
28
 
26
29
  @property
27
- def metrics(self) -> list[dict[str, float]]:
30
+ def metrics(self) -> dict[str, list[float]]:
28
31
  """Get the metrics for each run in the collection."""
29
- return [run.data.metrics for run in self._runs]
32
+ return _to_dict(run.data.metrics for run in self._runs)
30
33
 
31
34
  @property
32
- def config(self) -> list[DictConfig]:
33
- """Get the configuration for each run in the collection."""
34
- return [load_config(run) for run in self._runs]
35
+ def config(self) -> DataFrame:
36
+ """Get the runs' configurations as a polars DataFrame.
37
+
38
+ Returns:
39
+ A polars DataFrame containing the runs' configurations.
40
+
41
+ """
42
+ return DataFrame(self._runs.map_config(collect_params))
43
+
44
+
45
+ def _to_dict(it: Iterable[dict[str, Any]]) -> dict[str, list[Any]]:
46
+ """Convert an iterable of dictionaries to a dictionary of lists."""
47
+ data = list(it)
48
+ if not data:
49
+ return {}
50
+
51
+ keys = []
52
+ for d in data:
53
+ for key in d:
54
+ if key not in keys:
55
+ keys.append(key)
56
+
57
+ return {key: [x.get(key) for x in data] for key in keys}
hydraflow/utils.py CHANGED
@@ -68,6 +68,11 @@ def get_hydra_output_dir(run: Run | None = None) -> Path:
68
68
  raise FileNotFoundError
69
69
 
70
70
 
71
+ def get_overrides() -> list[str]:
72
+ """Retrieve the overrides for the current run."""
73
+ return HydraConfig.get().overrides.task
74
+
75
+
71
76
  def load_config(run: Run) -> DictConfig:
72
77
  """Load the configuration for a given run.
73
78
 
@@ -86,3 +91,23 @@ def load_config(run: Run) -> DictConfig:
86
91
  """
87
92
  path = get_artifact_dir(run) / ".hydra/config.yaml"
88
93
  return OmegaConf.load(path) # type: ignore
94
+
95
+
96
+ def load_overrides(run: Run) -> list[str]:
97
+ """Load the overrides for a given run.
98
+
99
+ This function loads the overrides for the provided Run instance
100
+ by downloading the overrides file from the MLflow artifacts and
101
+ loading it using OmegaConf. It returns an empty config if
102
+ `.hydra/overrides.yaml` is not found in the run's artifact directory.
103
+
104
+ Args:
105
+ run (Run): The Run instance for which to load the overrides.
106
+
107
+ Returns:
108
+ The loaded overrides as a list of strings. Returns an empty list
109
+ if the overrides file is not found.
110
+
111
+ """
112
+ path = get_artifact_dir(run) / ".hydra/overrides.yaml"
113
+ return [str(x) for x in OmegaConf.load(path)]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -0,0 +1,16 @@
1
+ hydraflow/__init__.py,sha256=aZWgotgg9BAYGOk6WAocOdo1M9ydAg_i7SThFeuLNo8,797
2
+ hydraflow/asyncio.py,sha256=-i1C8KAmNDImrjHnk92Csaa1mpjdK8Vp4ZVaQV-l94s,6634
3
+ hydraflow/config.py,sha256=6V5omJ3-h9-ZwVpM5rTA4FqE_mu8urTy9OqV4zG79gw,2671
4
+ hydraflow/context.py,sha256=412884e84qIEYtbxJT4roYsKfldGaTKzgo6Q1FAsT5U,8733
5
+ hydraflow/mlflow.py,sha256=JELqXFCJ9MsEJaQWT5dyleTFk8BHL7cQwW_gzhkPoIg,8729
6
+ hydraflow/param.py,sha256=QkLeQvt5ZF3GyRGnhP66o0GElc1ZOOCxCL7PdyfIUbA,1939
7
+ hydraflow/progress.py,sha256=zvKX1HCN8_xDOsgYOEcLLhkhdPdep-U8vHrc0XZ-6SQ,6163
8
+ hydraflow/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ hydraflow/run_collection.py,sha256=OiPibp8zM4gpsvqpHr9sBh4_1zmHRuXfPcmen7xND-s,24792
10
+ hydraflow/run_data.py,sha256=qeFX1iRvNAorXA9QQIjzr0o2_82TI44eZKp7llKG8GI,1549
11
+ hydraflow/run_info.py,sha256=sMXOo20ClaRIommMEzuAbO_OrcXx7M1Yt4FMV7spxz0,998
12
+ hydraflow/utils.py,sha256=XFZkUNQ6amYrlSJHIBoQvrxmDXwQG-M7T9BPpqid9Bc,3500
13
+ hydraflow-0.3.3.dist-info/METADATA,sha256=IH_V71WNJTLRItKdVASVm99ZmfXWQYj365RL_qDC_aY,3840
14
+ hydraflow-0.3.3.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
15
+ hydraflow-0.3.3.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
16
+ hydraflow-0.3.3.dist-info/RECORD,,
@@ -1,16 +0,0 @@
1
- hydraflow/__init__.py,sha256=6sfM1ashUkfrNf7lOR7raFYhG8YdOAJR-JgRNL_IVo8,698
2
- hydraflow/asyncio.py,sha256=-i1C8KAmNDImrjHnk92Csaa1mpjdK8Vp4ZVaQV-l94s,6634
3
- hydraflow/config.py,sha256=6V5omJ3-h9-ZwVpM5rTA4FqE_mu8urTy9OqV4zG79gw,2671
4
- hydraflow/context.py,sha256=412884e84qIEYtbxJT4roYsKfldGaTKzgo6Q1FAsT5U,8733
5
- hydraflow/mlflow.py,sha256=JELqXFCJ9MsEJaQWT5dyleTFk8BHL7cQwW_gzhkPoIg,8729
6
- hydraflow/param.py,sha256=dvIXcKgc_MPiju3WEk9qz5FOUeK5qSj-YWN2ophCpUM,1938
7
- hydraflow/progress.py,sha256=zvKX1HCN8_xDOsgYOEcLLhkhdPdep-U8vHrc0XZ-6SQ,6163
8
- hydraflow/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
- hydraflow/run_collection.py,sha256=Xv6-KD5ac-vv-4Q3PZrzJy1x84H_g7UoP7ZqZ8_DQeQ,24973
10
- hydraflow/run_data.py,sha256=ZXVr0PHyufH9wwyQYWtpE4_MheAC2ArTW_J1TTMQ4iI,983
11
- hydraflow/run_info.py,sha256=sMXOo20ClaRIommMEzuAbO_OrcXx7M1Yt4FMV7spxz0,998
12
- hydraflow/utils.py,sha256=aRdBdToKfvHhN2qFiRzPHIdQxS7cTpZREQeP8HreAfI,2676
13
- hydraflow-0.3.1.dist-info/METADATA,sha256=W38pNcCNy7Kmx1t9dwFoANsRjCk40-KBJUWux_BvHqA,3840
14
- hydraflow-0.3.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
15
- hydraflow-0.3.1.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
16
- hydraflow-0.3.1.dist-info/RECORD,,