hydraflow 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
hydraflow/__init__.py CHANGED
@@ -4,7 +4,13 @@ from .context import chdir_artifact, chdir_hydra, log_run, start_run, watch
4
4
  from .mlflow import list_runs, search_runs, set_experiment
5
5
  from .progress import multi_tasks_progress, parallel_progress
6
6
  from .run_collection import RunCollection
7
- from .utils import get_artifact_dir, get_hydra_output_dir, load_config
7
+ from .utils import (
8
+ get_artifact_dir,
9
+ get_hydra_output_dir,
10
+ get_overrides,
11
+ load_config,
12
+ load_overrides,
13
+ )
8
14
 
9
15
  __all__ = [
10
16
  "RunCollection",
@@ -12,8 +18,10 @@ __all__ = [
12
18
  "chdir_hydra",
13
19
  "get_artifact_dir",
14
20
  "get_hydra_output_dir",
21
+ "get_overrides",
15
22
  "list_runs",
16
23
  "load_config",
24
+ "load_overrides",
17
25
  "log_run",
18
26
  "multi_tasks_progress",
19
27
  "parallel_progress",
hydraflow/param.py CHANGED
@@ -72,4 +72,4 @@ def _match_tuple(param: str, value: tuple) -> bool | None:
72
72
  if type(value[0]) is not type(value[1]):
73
73
  return None
74
74
 
75
- return value[0] <= type(value[0])(param) < value[1] # type: ignore
75
+ return value[0] <= type(value[0])(param) <= value[1] # type: ignore
hydraflow/run_collection.py CHANGED
@@ -24,12 +24,12 @@ from itertools import chain
24
24
  from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
25
25
 
26
26
  from mlflow.entities import RunStatus
27
- from polars.dataframe import DataFrame
28
27
 
29
28
  import hydraflow.param
30
- from hydraflow.config import collect_params, iter_params
29
+ from hydraflow.config import iter_params
31
30
  from hydraflow.run_data import RunCollectionData
32
31
  from hydraflow.run_info import RunCollectionInfo
32
+ from hydraflow.utils import load_config
33
33
 
34
34
  if TYPE_CHECKING:
35
35
  from collections.abc import Callable, Iterator
@@ -239,8 +239,8 @@ class RunCollection:
239
239
  The filtering supports:
240
240
  - Exact matches for single values.
241
241
  - Membership checks for lists of values.
242
- - Range checks for tuples of two values (inclusive of the lower bound
243
- and exclusive of the upper bound).
242
+ - Range checks for tuples of two values (inclusive of both the lower
243
+ and upper bound).
244
244
 
245
245
  Args:
246
246
  config (object | None): The configuration object to filter the runs.
@@ -476,7 +476,7 @@ class RunCollection:
476
476
  """
477
477
  return (func(run, *args, **kwargs) for run in self)
478
478
 
479
- def map_run_id(
479
+ def map_id(
480
480
  self,
481
481
  func: Callable[Concatenate[str, P], T],
482
482
  *args: P.args,
@@ -516,7 +516,7 @@ class RunCollection:
516
516
  in the collection.
517
517
 
518
518
  """
519
- return (func(config, *args, **kwargs) for config in self.data.config)
519
+ return (func(load_config(run), *args, **kwargs) for run in self)
520
520
 
521
521
  def map_uri(
522
522
  self,
@@ -569,8 +569,8 @@ class RunCollection:
569
569
 
570
570
  def group_by(
571
571
  self,
572
- *names: str | list[str],
573
- ) -> dict[tuple[str | None, ...], RunCollection]:
572
+ names: str | list[str],
573
+ ) -> dict[str | None | tuple[str | None, ...], RunCollection]:
574
574
  """Group runs by specified parameter names.
575
575
 
576
576
  Group the runs in the collection based on the values of the
@@ -578,33 +578,27 @@ class RunCollection:
578
578
  form a key in the returned dictionary.
579
579
 
580
580
  Args:
581
- *names (str | list[str]): The names of the parameters to group by.
581
+ names (str | list[str]): The names of the parameters to group by.
582
582
  This can be a single parameter name or multiple names provided
583
583
  as a list.
584
584
 
585
585
  Returns:
586
- dict[tuple[str | None, ...], RunCollection]: A dictionary where the keys
587
- are tuples of parameter values and the values are RunCollection objects
588
- containing the runs that match those parameter values.
586
+ dict[str | None | tuple[str | None, ...], RunCollection]: A
587
+ dictionary where each key is a single parameter value (when `names`
588
+ is a string) or a tuple of values (when it is a list), and the
589
+ values are `RunCollection` objects containing the matching runs.
589
590
 
590
591
  """
591
- grouped_runs: dict[tuple[str | None, ...], list[Run]] = {}
592
+ grouped_runs: dict[str | None | tuple[str | None, ...], list[Run]] = {}
593
+ is_list = isinstance(names, list)
592
594
  for run in self._runs:
593
- key = get_params(run, *names)
595
+ key = get_params(run, names)
596
+ if not is_list:
597
+ key = key[0]
594
598
  grouped_runs.setdefault(key, []).append(run)
595
599
 
596
600
  return {key: RunCollection(runs) for key, runs in grouped_runs.items()}
597
601
 
598
- @property
599
- def config(self) -> DataFrame:
600
- """Get the runs' configurations as a polars DataFrame.
601
-
602
- Returns:
603
- A polars DataFrame containing the runs' configurations.
604
-
605
- """
606
- return DataFrame(self.map_config(collect_params))
607
-
608
602
 
609
603
  def _param_matches(run: Run, key: str, value: Any) -> bool:
610
604
  params = run.data.params
@@ -637,8 +631,8 @@ def filter_runs(
637
631
  The filtering supports:
638
632
  - Exact matches for single values.
639
633
  - Membership checks for lists of values.
640
- - Range checks for tuples of two values (inclusive of the lower bound and
641
- exclusive of the upper bound).
634
+ - Range checks for tuples of two values (inclusive of both the lower and
635
+ upper bound).
642
636
 
643
637
  Args:
644
638
  runs (list[Run]): The list of runs to filter.
hydraflow/run_data.py CHANGED
@@ -4,10 +4,13 @@ from __future__ import annotations
4
4
 
5
5
  from typing import TYPE_CHECKING
6
6
 
7
- from hydraflow.utils import load_config
7
+ from polars.dataframe import DataFrame
8
+
9
+ from hydraflow.config import collect_params
8
10
 
9
11
  if TYPE_CHECKING:
10
- from omegaconf import DictConfig
12
+ from collections.abc import Iterable
13
+ from typing import Any
11
14
 
12
15
  from hydraflow.run_collection import RunCollection
13
16
 
@@ -19,16 +22,36 @@ class RunCollectionData:
19
22
  self._runs = runs
20
23
 
21
24
  @property
22
- def params(self) -> list[dict[str, str]]:
25
+ def params(self) -> dict[str, list[str]]:
23
26
  """Get the parameters for each run in the collection."""
24
- return [run.data.params for run in self._runs]
27
+ return _to_dict(run.data.params for run in self._runs)
25
28
 
26
29
  @property
27
- def metrics(self) -> list[dict[str, float]]:
30
+ def metrics(self) -> dict[str, list[float]]:
28
31
  """Get the metrics for each run in the collection."""
29
- return [run.data.metrics for run in self._runs]
32
+ return _to_dict(run.data.metrics for run in self._runs)
30
33
 
31
34
  @property
32
- def config(self) -> list[DictConfig]:
33
- """Get the configuration for each run in the collection."""
34
- return [load_config(run) for run in self._runs]
35
+ def config(self) -> DataFrame:
36
+ """Get the runs' configurations as a polars DataFrame.
37
+
38
+ Returns:
39
+ A polars DataFrame containing the runs' configurations.
40
+
41
+ """
42
+ return DataFrame(self._runs.map_config(collect_params))
43
+
44
+
45
+ def _to_dict(it: Iterable[dict[str, Any]]) -> dict[str, list[Any]]:
46
+ """Convert an iterable of dictionaries to a dictionary of lists."""
47
+ data = list(it)
48
+ if not data:
49
+ return {}
50
+
51
+ keys = []
52
+ for d in data:
53
+ for key in d:
54
+ if key not in keys:
55
+ keys.append(key)
56
+
57
+ return {key: [x.get(key) for x in data] for key in keys}
hydraflow/utils.py CHANGED
@@ -68,6 +68,11 @@ def get_hydra_output_dir(run: Run | None = None) -> Path:
68
68
  raise FileNotFoundError
69
69
 
70
70
 
71
+ def get_overrides() -> list[str]:
72
+ """Retrieve the overrides for the current run."""
73
+ return HydraConfig.get().overrides.task
74
+
75
+
71
76
  def load_config(run: Run) -> DictConfig:
72
77
  """Load the configuration for a given run.
73
78
 
@@ -86,3 +91,23 @@ def load_config(run: Run) -> DictConfig:
86
91
  """
87
92
  path = get_artifact_dir(run) / ".hydra/config.yaml"
88
93
  return OmegaConf.load(path) # type: ignore
94
+
95
+
96
+ def load_overrides(run: Run) -> list[str]:
97
+ """Load the overrides for a given run.
98
+
99
+ This function loads the overrides for the provided Run instance
100
+ by downloading the overrides file from the MLflow artifacts and
101
+ loading it using OmegaConf. A missing `.hydra/overrides.yaml` is not
102
+ handled here: the path is passed to `OmegaConf.load` unconditionally.
103
+
104
+ Args:
105
+ run (Run): The Run instance for which to load the overrides.
106
+
107
+ Returns:
108
+ The loaded overrides as a list of strings, one per entry in the
109
+ overrides file.
110
+
111
+ """
112
+ path = get_artifact_dir(run) / ".hydra/overrides.yaml"
113
+ return [str(x) for x in OmegaConf.load(path)]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -0,0 +1,16 @@
1
+ hydraflow/__init__.py,sha256=aZWgotgg9BAYGOk6WAocOdo1M9ydAg_i7SThFeuLNo8,797
2
+ hydraflow/asyncio.py,sha256=-i1C8KAmNDImrjHnk92Csaa1mpjdK8Vp4ZVaQV-l94s,6634
3
+ hydraflow/config.py,sha256=6V5omJ3-h9-ZwVpM5rTA4FqE_mu8urTy9OqV4zG79gw,2671
4
+ hydraflow/context.py,sha256=412884e84qIEYtbxJT4roYsKfldGaTKzgo6Q1FAsT5U,8733
5
+ hydraflow/mlflow.py,sha256=JELqXFCJ9MsEJaQWT5dyleTFk8BHL7cQwW_gzhkPoIg,8729
6
+ hydraflow/param.py,sha256=QkLeQvt5ZF3GyRGnhP66o0GElc1ZOOCxCL7PdyfIUbA,1939
7
+ hydraflow/progress.py,sha256=zvKX1HCN8_xDOsgYOEcLLhkhdPdep-U8vHrc0XZ-6SQ,6163
8
+ hydraflow/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ hydraflow/run_collection.py,sha256=OiPibp8zM4gpsvqpHr9sBh4_1zmHRuXfPcmen7xND-s,24792
10
+ hydraflow/run_data.py,sha256=qeFX1iRvNAorXA9QQIjzr0o2_82TI44eZKp7llKG8GI,1549
11
+ hydraflow/run_info.py,sha256=sMXOo20ClaRIommMEzuAbO_OrcXx7M1Yt4FMV7spxz0,998
12
+ hydraflow/utils.py,sha256=XFZkUNQ6amYrlSJHIBoQvrxmDXwQG-M7T9BPpqid9Bc,3500
13
+ hydraflow-0.3.3.dist-info/METADATA,sha256=IH_V71WNJTLRItKdVASVm99ZmfXWQYj365RL_qDC_aY,3840
14
+ hydraflow-0.3.3.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
15
+ hydraflow-0.3.3.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
16
+ hydraflow-0.3.3.dist-info/RECORD,,
@@ -1,16 +0,0 @@
1
- hydraflow/__init__.py,sha256=6sfM1ashUkfrNf7lOR7raFYhG8YdOAJR-JgRNL_IVo8,698
2
- hydraflow/asyncio.py,sha256=-i1C8KAmNDImrjHnk92Csaa1mpjdK8Vp4ZVaQV-l94s,6634
3
- hydraflow/config.py,sha256=6V5omJ3-h9-ZwVpM5rTA4FqE_mu8urTy9OqV4zG79gw,2671
4
- hydraflow/context.py,sha256=412884e84qIEYtbxJT4roYsKfldGaTKzgo6Q1FAsT5U,8733
5
- hydraflow/mlflow.py,sha256=JELqXFCJ9MsEJaQWT5dyleTFk8BHL7cQwW_gzhkPoIg,8729
6
- hydraflow/param.py,sha256=dvIXcKgc_MPiju3WEk9qz5FOUeK5qSj-YWN2ophCpUM,1938
7
- hydraflow/progress.py,sha256=zvKX1HCN8_xDOsgYOEcLLhkhdPdep-U8vHrc0XZ-6SQ,6163
8
- hydraflow/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
- hydraflow/run_collection.py,sha256=Xv6-KD5ac-vv-4Q3PZrzJy1x84H_g7UoP7ZqZ8_DQeQ,24973
10
- hydraflow/run_data.py,sha256=ZXVr0PHyufH9wwyQYWtpE4_MheAC2ArTW_J1TTMQ4iI,983
11
- hydraflow/run_info.py,sha256=sMXOo20ClaRIommMEzuAbO_OrcXx7M1Yt4FMV7spxz0,998
12
- hydraflow/utils.py,sha256=aRdBdToKfvHhN2qFiRzPHIdQxS7cTpZREQeP8HreAfI,2676
13
- hydraflow-0.3.1.dist-info/METADATA,sha256=W38pNcCNy7Kmx1t9dwFoANsRjCk40-KBJUWux_BvHqA,3840
14
- hydraflow-0.3.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
15
- hydraflow-0.3.1.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
16
- hydraflow-0.3.1.dist-info/RECORD,,