hydraflow 0.3.5__py3-none-any.whl → 0.4.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
hydraflow/__init__.py CHANGED
@@ -1,6 +1,6 @@
1
1
  """Integrate Hydra and MLflow to manage and track machine learning experiments."""
2
2
 
3
- from .config import select_config
3
+ from .config import select_config, select_overrides
4
4
  from .context import chdir_artifact, chdir_hydra_output, log_run, start_run, watch
5
5
  from .mlflow import list_runs, search_runs, set_experiment
6
6
  from .progress import multi_tasks_progress, parallel_progress
@@ -28,6 +28,7 @@ __all__ = [
28
28
  "parallel_progress",
29
29
  "search_runs",
30
30
  "select_config",
31
+ "select_overrides",
31
32
  "set_experiment",
32
33
  "start_run",
33
34
  "watch",
hydraflow/config.py CHANGED
@@ -6,6 +6,8 @@ from typing import TYPE_CHECKING
6
6
 
7
7
  from omegaconf import DictConfig, ListConfig, OmegaConf
8
8
 
9
+ from hydraflow.utils import get_overrides
10
+
9
11
  if TYPE_CHECKING:
10
12
  from collections.abc import Iterator
11
13
  from typing import Any
@@ -116,9 +118,9 @@ def select_config(config: object, names: list[str]) -> dict[str, Any]:
116
118
 
117
119
  """
118
120
  if not isinstance(config, DictConfig):
119
- cfg = OmegaConf.structured(config)
121
+ config = OmegaConf.structured(config)
120
122
 
121
- return {name: _get(cfg, name) for name in names}
123
+ return {name: _get(config, name) for name in names} # type: ignore
122
124
 
123
125
 
124
126
  def _get(config: DictConfig, name: str) -> Any:
@@ -128,3 +130,10 @@ def _get(config: DictConfig, name: str) -> Any:
128
130
 
129
131
  prefix, name = name.split(".", 1)
130
132
  return _get(config.get(prefix), name)
133
+
134
+
135
+ def select_overrides(config: object) -> dict[str, Any]:
136
+ """Select the given overrides from the configuration object."""
137
+ overrides = get_overrides()
138
+ names = [override.split("=")[0].strip() for override in overrides]
139
+ return select_config(config, names)
@@ -26,7 +26,7 @@ from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
26
26
  from mlflow.entities import RunStatus
27
27
 
28
28
  import hydraflow.param
29
- from hydraflow.config import iter_params
29
+ from hydraflow.config import iter_params, select_config, select_overrides
30
30
  from hydraflow.run_data import RunCollectionData
31
31
  from hydraflow.run_info import RunCollectionInfo
32
32
  from hydraflow.utils import load_config
@@ -603,7 +603,7 @@ class RunCollection:
603
603
  def _param_matches(run: Run, key: str, value: Any) -> bool:
604
604
  params = run.data.params
605
605
  if key not in params:
606
- return True
606
+ return False
607
607
 
608
608
  param = params[key]
609
609
  if param == "None":
@@ -616,6 +616,8 @@ def filter_runs(
616
616
  runs: list[Run],
617
617
  config: object | None = None,
618
618
  *,
619
+ override: bool = False,
620
+ select: list[str] | None = None,
619
621
  status: str | list[str] | int | list[int] | None = None,
620
622
  **kwargs,
621
623
  ) -> list[Run]:
@@ -636,17 +638,26 @@ def filter_runs(
636
638
 
637
639
  Args:
638
640
  runs (list[Run]): The list of runs to filter.
639
- config (object | None): The configuration object to filter the runs.
640
- This can be any object that provides key-value pairs through the
641
- `iter_params` function.
642
- status (str | list[str] | RunStatus | list[RunStatus] | None): The status of
643
- the runs to filter.
641
+ config (object | None, optional): The configuration object to filter the
642
+ runs. This can be any object that provides key-value pairs through
643
+ the `iter_params` function. Defaults to None.
644
+ override (bool, optional): If True, filter the runs based on
645
+ the overrides. Defaults to False.
646
+ select (list[str] | None, optional): The list of parameters to select.
647
+ Defaults to None.
648
+ status (str | list[str] | RunStatus | list[RunStatus] | None, optional): The
649
+ status of the runs to filter. Defaults to None.
644
650
  **kwargs: Additional key-value pairs to filter the runs.
645
651
 
646
652
  Returns:
647
653
  A list of runs that match the specified configuration and key-value pairs.
648
654
 
649
655
  """
656
+ if override:
657
+ config = select_overrides(config)
658
+ elif select:
659
+ config = select_config(config, select)
660
+
650
661
  for key, value in chain(iter_params(config), kwargs.items()):
651
662
  runs = [run for run in runs if _param_matches(run, key, value)]
652
663
  if not runs:
hydraflow/utils.py CHANGED
@@ -90,7 +90,7 @@ def load_config(run: Run) -> DictConfig:
90
90
 
91
91
  def get_overrides() -> list[str]:
92
92
  """Retrieve the overrides for the current run."""
93
- return HydraConfig.get().overrides.task
93
+ return list(HydraConfig.get().overrides.task) # ListConifg -> list
94
94
 
95
95
 
96
96
  def load_overrides(run: Run) -> list[str]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.3.5
3
+ Version: 0.4.0
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -1,16 +1,16 @@
1
- hydraflow/__init__.py,sha256=p0GzgI2RYXuaBdgoDgmihkn2esUlcbPJlSwoXBdbZEw,866
1
+ hydraflow/__init__.py,sha256=VbrHKs2Cg93QJ8K9WHYxkXmzOpb8o9ugiwV-mXDT0JE,908
2
2
  hydraflow/asyncio.py,sha256=-i1C8KAmNDImrjHnk92Csaa1mpjdK8Vp4ZVaQV-l94s,6634
3
- hydraflow/config.py,sha256=mQYEW_-kr-5cz64hGxI_gkK8XK_54M-3DX-YlE44A9A,3992
3
+ hydraflow/config.py,sha256=MNX9da5bPVDcjnpji7Cm9ndK6ura92pt361m4PRh6_E,4326
4
4
  hydraflow/context.py,sha256=p1UYHvSCPrp10cBn9TUI9mXMv0h_I0Eou24Wp1rZZ0k,8740
5
5
  hydraflow/mlflow.py,sha256=JELqXFCJ9MsEJaQWT5dyleTFk8BHL7cQwW_gzhkPoIg,8729
6
6
  hydraflow/param.py,sha256=CO-6PRlnHo-7hlY_P6j_cGlC7vPY6t-Rr7p3OqeqDyU,1995
7
7
  hydraflow/progress.py,sha256=zvKX1HCN8_xDOsgYOEcLLhkhdPdep-U8vHrc0XZ-6SQ,6163
8
8
  hydraflow/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
- hydraflow/run_collection.py,sha256=OiPibp8zM4gpsvqpHr9sBh4_1zmHRuXfPcmen7xND-s,24792
9
+ hydraflow/run_collection.py,sha256=-PEN8vO4beQkxhEQH9xh0_TzEIO34-eulwRt7WidrIA,25295
10
10
  hydraflow/run_data.py,sha256=qeFX1iRvNAorXA9QQIjzr0o2_82TI44eZKp7llKG8GI,1549
11
11
  hydraflow/run_info.py,sha256=sMXOo20ClaRIommMEzuAbO_OrcXx7M1Yt4FMV7spxz0,998
12
- hydraflow/utils.py,sha256=qNN0JDbQJweTkcRMpZwMXTIuK_LVpoDJ2rOwfe01f3U,3500
13
- hydraflow-0.3.5.dist-info/METADATA,sha256=UVUeGwlz9ZcdqOX-tWJ3MJX1KD05i4lkPj455J2vAik,3840
14
- hydraflow-0.3.5.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
15
- hydraflow-0.3.5.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
16
- hydraflow-0.3.5.dist-info/RECORD,,
12
+ hydraflow/utils.py,sha256=gNI0Ln2VBHfMBzNB9SNxJfjCLf14irYt0EBeeMXMeyk,3528
13
+ hydraflow-0.4.0.dist-info/METADATA,sha256=w0gsff6RLwx4l8_Qlw-1SOu0dY4ySPCq2psm0THmvcY,3840
14
+ hydraflow-0.4.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
15
+ hydraflow-0.4.0.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
16
+ hydraflow-0.4.0.dist-info/RECORD,,