hydraflow 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
hydraflow/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """Integrate Hydra and MLflow to manage and track machine learning experiments."""
2
2
 
3
+ from .config import select_config, select_overrides
3
4
  from .context import chdir_artifact, chdir_hydra_output, log_run, start_run, watch
4
5
  from .mlflow import list_runs, search_runs, set_experiment
5
6
  from .progress import multi_tasks_progress, parallel_progress
@@ -26,6 +27,8 @@ __all__ = [
26
27
  "multi_tasks_progress",
27
28
  "parallel_progress",
28
29
  "search_runs",
30
+ "select_config",
31
+ "select_overrides",
29
32
  "set_experiment",
30
33
  "start_run",
31
34
  "watch",
hydraflow/config.py CHANGED
@@ -6,6 +6,8 @@ from typing import TYPE_CHECKING
6
6
 
7
7
  from omegaconf import DictConfig, ListConfig, OmegaConf
8
8
 
9
+ from hydraflow.utils import get_overrides
10
+
9
11
  if TYPE_CHECKING:
10
12
  from collections.abc import Iterator
11
13
  from typing import Any
@@ -99,3 +101,39 @@ def _convert(value: Any) -> Any:
99
101
  return list(value)
100
102
 
101
103
  return value
104
+
105
+
106
def select_config(config: object, names: list[str]) -> dict[str, Any]:
    """Select the given parameters from the configuration object.

    The configuration object is converted to a `DictConfig` with
    `OmegaConf.structured` unless it already is one, and each requested
    parameter is looked up in it. A name may be a dotted path such as
    ``"model.lr"`` to address a nested parameter.

    Args:
        config (object): The configuration object to select parameters from.
            Either a `DictConfig` or any object accepted by
            `OmegaConf.structured`.
        names (list[str]): The names of the parameters to select. Dots in a
            name separate nesting levels.

    Returns:
        dict[str, Any]: A plain dictionary mapping each requested name
        (verbatim, including any dots) to the value found in the
        configuration.

    """
    # Bind a new, properly typed local instead of reassigning the
    # `object`-typed parameter, so no `type: ignore` is needed below.
    cfg = config if isinstance(config, DictConfig) else OmegaConf.structured(config)

    return {name: _get(cfg, name) for name in names}
124
+
125
+
126
+ def _get(config: DictConfig, name: str) -> Any:
127
+ """Get the value of the given parameter from the configuration object."""
128
+ if "." not in name:
129
+ return config.get(name)
130
+
131
+ prefix, name = name.split(".", 1)
132
+ return _get(config.get(prefix), name)
133
+
134
+
135
def select_overrides(config: object) -> dict[str, Any]:
    """Select the parameters overridden in the current Hydra run.

    Parameter names are taken from the task overrides returned by
    `get_overrides` (e.g. ``"model.lr=0.1"`` gives ``"model.lr"``) and then
    looked up in ``config`` with `select_config`.

    Hydra override prefixes are handled:

    - ``+key=value`` and ``++key=value`` add or force-add ``key``, so the
      leading ``+`` characters are stripped.
    - ``~key`` deletes ``key`` from the configuration, so it is skipped
      because there is no value left to select.

    Args:
        config (object): The configuration object to select parameters from.

    Returns:
        dict[str, Any]: A mapping of each overridden parameter name to its
        value in ``config``.

    """
    keys = (override.split("=", 1)[0].strip() for override in get_overrides())
    names = [key.lstrip("+") for key in keys if not key.startswith("~")]
    # NOTE(review): package overrides such as ``group@pkg=option`` are passed
    # through verbatim; confirm whether they need mapping to a config path.
    return select_config(config, names)
@@ -26,7 +26,7 @@ from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
26
26
  from mlflow.entities import RunStatus
27
27
 
28
28
  import hydraflow.param
29
- from hydraflow.config import iter_params
29
+ from hydraflow.config import iter_params, select_config, select_overrides
30
30
  from hydraflow.run_data import RunCollectionData
31
31
  from hydraflow.run_info import RunCollectionInfo
32
32
  from hydraflow.utils import load_config
@@ -616,6 +616,8 @@ def filter_runs(
616
616
  runs: list[Run],
617
617
  config: object | None = None,
618
618
  *,
619
+ override: bool = False,
620
+ select: list[str] | None = None,
619
621
  status: str | list[str] | int | list[int] | None = None,
620
622
  **kwargs,
621
623
  ) -> list[Run]:
@@ -636,17 +638,26 @@ def filter_runs(
636
638
 
637
639
  Args:
638
640
  runs (list[Run]): The list of runs to filter.
639
- config (object | None): The configuration object to filter the runs.
640
- This can be any object that provides key-value pairs through the
641
- `iter_params` function.
642
- status (str | list[str] | RunStatus | list[RunStatus] | None): The status of
643
- the runs to filter.
641
+ config (object | None, optional): The configuration object to filter the
642
+ runs. This can be any object that provides key-value pairs through
643
+ the `iter_params` function. Defaults to None.
644
+ override (bool, optional): If True, filter the runs based on
645
+ the overrides. Defaults to False.
646
+ select (list[str] | None, optional): The list of parameters to select.
647
+ Defaults to None.
648
+ status (str | list[str] | RunStatus | list[RunStatus] | None, optional): The
649
+ status of the runs to filter. Defaults to None.
644
650
  **kwargs: Additional key-value pairs to filter the runs.
645
651
 
646
652
  Returns:
647
653
  A list of runs that match the specified configuration and key-value pairs.
648
654
 
649
655
  """
656
+ if override:
657
+ config = select_overrides(config)
658
+ elif select:
659
+ config = select_config(config, select)
660
+
650
661
  for key, value in chain(iter_params(config), kwargs.items()):
651
662
  runs = [run for run in runs if _param_matches(run, key, value)]
652
663
  if not runs:
hydraflow/utils.py CHANGED
@@ -90,7 +90,7 @@ def load_config(run: Run) -> DictConfig:
90
90
 
91
91
  def get_overrides() -> list[str]:
92
92
  """Retrieve the overrides for the current run."""
93
- return HydraConfig.get().overrides.task
93
+ return list(HydraConfig.get().overrides.task) # ListConifg -> list
94
94
 
95
95
 
96
96
  def load_overrides(run: Run) -> list[str]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.3.4
3
+ Version: 0.3.6
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -1,16 +1,16 @@
1
- hydraflow/__init__.py,sha256=HVVsSRK2BXtHdoueFNRva924th49_lCV5zYb2c8qdaw,811
1
+ hydraflow/__init__.py,sha256=VbrHKs2Cg93QJ8K9WHYxkXmzOpb8o9ugiwV-mXDT0JE,908
2
2
  hydraflow/asyncio.py,sha256=-i1C8KAmNDImrjHnk92Csaa1mpjdK8Vp4ZVaQV-l94s,6634
3
- hydraflow/config.py,sha256=dRG4cFqDH0vBX109q0C46jiXdvbYoYgu651D6KlmmxQ,3021
3
+ hydraflow/config.py,sha256=MNX9da5bPVDcjnpji7Cm9ndK6ura92pt361m4PRh6_E,4326
4
4
  hydraflow/context.py,sha256=p1UYHvSCPrp10cBn9TUI9mXMv0h_I0Eou24Wp1rZZ0k,8740
5
5
  hydraflow/mlflow.py,sha256=JELqXFCJ9MsEJaQWT5dyleTFk8BHL7cQwW_gzhkPoIg,8729
6
6
  hydraflow/param.py,sha256=CO-6PRlnHo-7hlY_P6j_cGlC7vPY6t-Rr7p3OqeqDyU,1995
7
7
  hydraflow/progress.py,sha256=zvKX1HCN8_xDOsgYOEcLLhkhdPdep-U8vHrc0XZ-6SQ,6163
8
8
  hydraflow/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
- hydraflow/run_collection.py,sha256=OiPibp8zM4gpsvqpHr9sBh4_1zmHRuXfPcmen7xND-s,24792
9
+ hydraflow/run_collection.py,sha256=qz74Ct97Fa1FTFM742NlcWGW6NjxuDs5sQitJ8ijEJY,25294
10
10
  hydraflow/run_data.py,sha256=qeFX1iRvNAorXA9QQIjzr0o2_82TI44eZKp7llKG8GI,1549
11
11
  hydraflow/run_info.py,sha256=sMXOo20ClaRIommMEzuAbO_OrcXx7M1Yt4FMV7spxz0,998
12
- hydraflow/utils.py,sha256=qNN0JDbQJweTkcRMpZwMXTIuK_LVpoDJ2rOwfe01f3U,3500
13
- hydraflow-0.3.4.dist-info/METADATA,sha256=DJuHernRtS9IjpWRqCIt-yDXdFBDioSf3f3-T15D-yI,3840
14
- hydraflow-0.3.4.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
15
- hydraflow-0.3.4.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
16
- hydraflow-0.3.4.dist-info/RECORD,,
12
+ hydraflow/utils.py,sha256=gNI0Ln2VBHfMBzNB9SNxJfjCLf14irYt0EBeeMXMeyk,3528
13
+ hydraflow-0.3.6.dist-info/METADATA,sha256=Vn7lEY8dSfvK5V7q0-TVx7cM1AJo2mZCc4UyE3uOElo,3840
14
+ hydraflow-0.3.6.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
15
+ hydraflow-0.3.6.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
16
+ hydraflow-0.3.6.dist-info/RECORD,,