hydraflow 0.4.0__tar.gz → 0.4.2__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (46) hide show
  1. hydraflow-0.4.2/.devcontainer/devcontainer.json +18 -0
  2. hydraflow-0.4.2/.devcontainer/postCreate.sh +10 -0
  3. {hydraflow-0.4.0 → hydraflow-0.4.2}/PKG-INFO +1 -1
  4. {hydraflow-0.4.0 → hydraflow-0.4.2}/pyproject.toml +1 -1
  5. {hydraflow-0.4.0 → hydraflow-0.4.2}/src/hydraflow/context.py +1 -0
  6. {hydraflow-0.4.0 → hydraflow-0.4.2}/src/hydraflow/mlflow.py +1 -0
  7. hydraflow-0.4.2/src/hydraflow/param.py +162 -0
  8. {hydraflow-0.4.0 → hydraflow-0.4.2}/src/hydraflow/run_collection.py +68 -47
  9. {hydraflow-0.4.0 → hydraflow-0.4.2}/src/hydraflow/utils.py +1 -0
  10. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/integ/app.py +12 -4
  11. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/scripts/app.py +13 -0
  12. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_app.py +34 -0
  13. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_asyncio.py +55 -0
  14. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_param.py +39 -0
  15. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_run_collection.py +7 -1
  16. hydraflow-0.4.0/.devcontainer/devcontainer.json +0 -17
  17. hydraflow-0.4.0/.devcontainer/postCreate.sh +0 -5
  18. hydraflow-0.4.0/src/hydraflow/param.py +0 -78
  19. {hydraflow-0.4.0 → hydraflow-0.4.2}/.devcontainer/starship.toml +0 -0
  20. {hydraflow-0.4.0 → hydraflow-0.4.2}/.gitattributes +0 -0
  21. {hydraflow-0.4.0 → hydraflow-0.4.2}/.gitignore +0 -0
  22. {hydraflow-0.4.0 → hydraflow-0.4.2}/LICENSE +0 -0
  23. {hydraflow-0.4.0 → hydraflow-0.4.2}/README.md +0 -0
  24. {hydraflow-0.4.0 → hydraflow-0.4.2}/apps/quickstart.py +0 -0
  25. {hydraflow-0.4.0 → hydraflow-0.4.2}/mkdocs.yml +0 -0
  26. {hydraflow-0.4.0 → hydraflow-0.4.2}/src/hydraflow/__init__.py +0 -0
  27. {hydraflow-0.4.0 → hydraflow-0.4.2}/src/hydraflow/asyncio.py +0 -0
  28. {hydraflow-0.4.0 → hydraflow-0.4.2}/src/hydraflow/config.py +0 -0
  29. {hydraflow-0.4.0 → hydraflow-0.4.2}/src/hydraflow/progress.py +0 -0
  30. {hydraflow-0.4.0 → hydraflow-0.4.2}/src/hydraflow/py.typed +0 -0
  31. {hydraflow-0.4.0 → hydraflow-0.4.2}/src/hydraflow/run_data.py +0 -0
  32. {hydraflow-0.4.0 → hydraflow-0.4.2}/src/hydraflow/run_info.py +0 -0
  33. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/__init__.py +0 -0
  34. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/integ/__init__.py +0 -0
  35. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/scripts/__init__.py +0 -0
  36. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/scripts/progress.py +0 -0
  37. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/scripts/watch.py +0 -0
  38. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_config.py +0 -0
  39. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_context.py +0 -0
  40. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_log_run.py +0 -0
  41. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_mlflow.py +0 -0
  42. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_progress.py +0 -0
  43. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_run_data.py +0 -0
  44. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_run_info.py +0 -0
  45. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_version.py +0 -0
  46. {hydraflow-0.4.0 → hydraflow-0.4.2}/tests/test_watch.py +0 -0
@@ -0,0 +1,18 @@
1
+ {
2
+ "image": "mcr.microsoft.com/vscode/devcontainers/base:ubuntu24.04",
3
+ "features": {
4
+ "ghcr.io/devcontainers-contrib/features/starship:1": {}
5
+ },
6
+ "customizations": {
7
+ "vscode": {
8
+ "extensions": [
9
+ "charliermarsh.ruff",
10
+ "fill-labs.dependi",
11
+ "ms-python.python",
12
+ "ms-python.vscode-pylance",
13
+ "tamasfe.even-better-toml"
14
+ ]
15
+ }
16
+ },
17
+ "postCreateCommand": ".devcontainer/postCreate.sh"
18
+ }
@@ -0,0 +1,10 @@
1
+ #!/bin/bash
2
+
3
+ echo 'eval "$(starship init bash)"' >> ~/.bashrc
4
+ echo "alias ll='ls -alF'" >> ~/.bashrc
5
+ mkdir -p ~/.config
6
+ cp .devcontainer/starship.toml ~/.config
7
+
8
+ curl -LsSf https://astral.sh/uv/install.sh | sh
9
+ source $HOME/.cargo/env
10
+ echo 'eval "$(uv generate-shell-completion bash)"' >> ~/.bashrc
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.4.0
3
+ Version: 0.4.2
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.4.0"
7
+ version = "0.4.2"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -10,6 +10,7 @@ from pathlib import Path
10
10
  from typing import TYPE_CHECKING
11
11
 
12
12
  import mlflow
13
+ import mlflow.artifacts
13
14
  from hydra.core.hydra_config import HydraConfig
14
15
  from watchdog.events import FileModifiedEvent, PatternMatchingEventHandler
15
16
  from watchdog.observers import Observer
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
21
21
 
22
22
  import joblib
23
23
  import mlflow
24
+ import mlflow.artifacts
24
25
  from hydra.core.hydra_config import HydraConfig
25
26
  from mlflow.entities import ViewType
26
27
  from mlflow.tracking.fluent import SEARCH_MAX_RESULTS_PANDAS, _get_experiment_id
@@ -0,0 +1,162 @@
1
+ """Provide utility functions for parameter matching.
2
+
3
+ The main function `match` checks if a given parameter matches a specified value.
4
+ It supports various types of values including None, boolean, list, tuple, int,
5
+ float, and str.
6
+
7
+ Helper functions `_match_list` and `_match_tuple` are used internally to handle
8
+ matching for list and tuple types respectively.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from typing import TYPE_CHECKING, Any
14
+
15
+ from omegaconf import ListConfig, OmegaConf
16
+
17
+ if TYPE_CHECKING:
18
+ from mlflow.entities import Run
19
+
20
+
21
+ def match(param: str, value: Any) -> bool:
22
+ """Check if the string matches the specified value.
23
+
24
+ Args:
25
+ param (str): The parameter to check.
26
+ value (Any): The value to check.
27
+
28
+ Returns:
29
+ True if the parameter matches the specified value,
30
+ False otherwise.
31
+
32
+ """
33
+ if any(value is x for x in [None, True, False]):
34
+ return param == str(value)
35
+
36
+ if isinstance(value, list) and (m := _match_list(param, value)) is not None:
37
+ return m
38
+
39
+ if isinstance(value, tuple) and (m := _match_tuple(param, value)) is not None:
40
+ return m
41
+
42
+ if isinstance(value, int | float):
43
+ return float(param) == value
44
+
45
+ if isinstance(value, str):
46
+ return param == value
47
+
48
+ return param == str(value)
49
+
50
+
51
+ def _match_list(param: str, value: list) -> bool | None:
52
+ if not value:
53
+ return None
54
+
55
+ if any(param.startswith(x) for x in ["[", "(", "{"]):
56
+ return None
57
+
58
+ if isinstance(value[0], bool):
59
+ return None
60
+
61
+ if not isinstance(value[0], int | float | str):
62
+ return None
63
+
64
+ return type(value[0])(param) in value
65
+
66
+
67
+ def _match_tuple(param: str, value: tuple) -> bool | None:
68
+ if len(value) != 2: # noqa: PLR2004
69
+ return None
70
+
71
+ if any(param.startswith(x) for x in ["[", "(", "{"]):
72
+ return None
73
+
74
+ if isinstance(value[0], bool):
75
+ return None
76
+
77
+ if not isinstance(value[0], int | float | str):
78
+ return None
79
+
80
+ if type(value[0]) is not type(value[1]):
81
+ return None
82
+
83
+ return value[0] <= type(value[0])(param) <= value[1] # type: ignore
84
+
85
+
86
+ def to_value(param: str | None, type_: type) -> Any:
87
+ """Convert the parameter to the specified type.
88
+
89
+ Args:
90
+ param (str | None): The parameter to convert.
91
+ type_ (type): The type to convert to.
92
+
93
+ Returns:
94
+ The converted value.
95
+
96
+ """
97
+ if param is None or param == "None":
98
+ return None
99
+
100
+ if type_ is int:
101
+ return int(param)
102
+
103
+ if type_ is float:
104
+ return float(param)
105
+
106
+ if type_ is bool:
107
+ return param == "True"
108
+
109
+ if type_ is list or type_ is ListConfig:
110
+ return list(OmegaConf.create(param))
111
+
112
+ return param
113
+
114
+
115
+ def get_params(run: Run, *names: str | list[str]) -> tuple[str | None, ...]:
116
+ """Retrieve the values of specified parameters from the given run.
117
+
118
+ This function extracts the values of the parameters identified by the
119
+ provided names from the specified run. It can accept both individual
120
+ parameter names and lists of parameter names.
121
+
122
+ Args:
123
+ run (Run): The run object from which to extract parameter values.
124
+ *names (str | list[str]): The names of the parameters to retrieve.
125
+ This can be a single parameter name or multiple names provided
126
+ as separate arguments or as a list.
127
+
128
+ Returns:
129
+ tuple[str | None, ...]: A tuple containing the values of the specified
130
+ parameters in the order they were provided.
131
+
132
+ """
133
+ names_ = []
134
+ for name in names:
135
+ if isinstance(name, list):
136
+ names_.extend(name)
137
+ else:
138
+ names_.append(name)
139
+
140
+ params = run.data.params
141
+ return tuple(params.get(name) for name in names_)
142
+
143
+
144
+ def get_values(run: Run, names: list[str], types: list[type]) -> tuple[Any, ...]:
145
+ """Retrieve the values of specified parameters from the given run.
146
+
147
+ This function extracts the values of the parameters identified by the
148
+ provided names from the specified run.
149
+
150
+ Args:
151
+ run (Run): The run object from which to extract parameter values.
152
+ names (list[str]): The names of the parameters to retrieve.
153
+ types (list[type]): The types to convert to.
154
+
155
+ Returns:
156
+ tuple[Any, ...]: A tuple containing the values of the specified
157
+ parameters in the order they were provided.
158
+
159
+ """
160
+ params = get_params(run, names)
161
+ it = zip(params, types, strict=True)
162
+ return tuple(to_value(param, type_) for param, type_ in it)
@@ -27,6 +27,7 @@ from mlflow.entities import RunStatus
27
27
 
28
28
  import hydraflow.param
29
29
  from hydraflow.config import iter_params, select_config, select_overrides
30
+ from hydraflow.param import get_params, get_values
30
31
  from hydraflow.run_data import RunCollectionData
31
32
  from hydraflow.run_info import RunCollectionInfo
32
33
  from hydraflow.utils import load_config
@@ -132,25 +133,6 @@ class RunCollection:
132
133
 
133
134
  return self.__class__(self._runs[:n])
134
135
 
135
- def sort(
136
- self,
137
- key: Callable[[Run], Any] | None = None,
138
- *,
139
- reverse: bool = False,
140
- ) -> None:
141
- """Sort the runs in the collection.
142
-
143
- Sort the runs in the collection according to the provided key function
144
- and optional reverse flag.
145
-
146
- Args:
147
- key (Callable[[Run], Any] | None): A function that takes a run and returns
148
- a value to sort by.
149
- reverse (bool): If True, sort in descending order.
150
-
151
- """
152
- self._runs.sort(key=key or (lambda x: x.info.start_time), reverse=reverse)
153
-
154
136
  def one(self) -> Run:
155
137
  """Get the only `Run` instance in the collection.
156
138
 
@@ -599,6 +581,73 @@ class RunCollection:
599
581
 
600
582
  return {key: RunCollection(runs) for key, runs in grouped_runs.items()}
601
583
 
584
+ def sort(
585
+ self,
586
+ key: Callable[[Run], Any] | None = None,
587
+ *,
588
+ reverse: bool = False,
589
+ ) -> None:
590
+ """Sort the runs in the collection.
591
+
592
+ Sort the runs in the collection according to the provided key function
593
+ and optional reverse flag.
594
+
595
+ Args:
596
+ key (Callable[[Run], Any] | None): A function that takes a run and returns
597
+ a value to sort by.
598
+ reverse (bool): If True, sort in descending order.
599
+
600
+ """
601
+ self._runs.sort(key=key or (lambda x: x.info.start_time), reverse=reverse)
602
+
603
+ def values(self, names: str | list[str]) -> list[Any]:
604
+ """Get the values of specified parameters from the runs.
605
+
606
+ Args:
607
+ names (str | list[str]): The names of the parameters to get the values.
608
+ This can be a single parameter name or multiple names provided
609
+ as separate arguments or as a list.
610
+
611
+ Returns:
612
+ A list of values for the specified parameters.
613
+
614
+ """
615
+ is_list = isinstance(names, list)
616
+
617
+ if isinstance(names, str):
618
+ names = [names]
619
+
620
+ config = load_config(self.first())
621
+ types = [type(v) for v in select_config(config, names).values()]
622
+ values = [get_values(run, names, types) for run in self]
623
+
624
+ if is_list:
625
+ return values
626
+
627
+ return [v[0] for v in values]
628
+
629
+ def sort_by(
630
+ self,
631
+ names: str | list[str],
632
+ *,
633
+ reverse: bool = False,
634
+ ) -> RunCollection:
635
+ """Sort the runs in the collection by specified parameter names.
636
+
637
+ Sort the runs in the collection based on the values of the specified
638
+ parameters.
639
+
640
+ Args:
641
+ names (str | list[str]): The names of the parameters to sort by.
642
+ This can be a single parameter name or multiple names provided
643
+ as separate arguments or as a list.
644
+ reverse (bool): If True, sort in descending order.
645
+
646
+ """
647
+ values = self.values(names)
648
+ index = sorted(range(len(self)), key=lambda i: values[i], reverse=reverse)
649
+ return RunCollection([self[i] for i in index])
650
+
602
651
 
603
652
  def _param_matches(run: Run, key: str, value: Any) -> bool:
604
653
  params = run.data.params
@@ -703,31 +752,3 @@ def _to_lower(status: str | int) -> str:
703
752
  return status.lower()
704
753
 
705
754
  return RunStatus.to_string(status).lower()
706
-
707
-
708
- def get_params(run: Run, *names: str | list[str]) -> tuple[str | None, ...]:
709
- """Retrieve the values of specified parameters from the given run.
710
-
711
- This function extracts the values of the parameters identified by the
712
- provided names from the specified run. It can accept both individual
713
- parameter names and lists of parameter names.
714
-
715
- Args:
716
- run (Run): The run object from which to extract parameter values.
717
- *names (str | list[str]): The names of the parameters to retrieve.
718
- This can be a single parameter name or multiple names provided
719
- as separate arguments or as a list.
720
-
721
- Returns:
722
- tuple[str | None, ...]: A tuple containing the values of the specified
723
- parameters in the order they were provided.
724
-
725
- """
726
- names_ = []
727
- for name in names:
728
- if isinstance(name, list):
729
- names_.extend(name)
730
- else:
731
- names_.append(name)
732
-
733
- return tuple(run.data.params.get(name) for name in names_)
@@ -6,6 +6,7 @@ from pathlib import Path
6
6
  from typing import TYPE_CHECKING
7
7
 
8
8
  import mlflow
9
+ import mlflow.artifacts
9
10
  from hydra.core.hydra_config import HydraConfig
10
11
  from mlflow.entities import Run
11
12
  from mlflow.tracking import artifact_utils
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import logging
4
4
  from dataclasses import dataclass, field
5
+ from enum import Enum, auto
5
6
 
6
7
  import hydra
7
8
  from hydra.core.config_store import ConfigStore
@@ -11,6 +12,11 @@ import hydraflow
11
12
  log = logging.getLogger(__name__)
12
13
 
13
14
 
15
+ class E(Enum):
16
+ A = auto()
17
+ B = auto()
18
+
19
+
14
20
  @dataclass
15
21
  class B:
16
22
  z: float = 0.0
@@ -20,6 +26,7 @@ class B:
20
26
  class A:
21
27
  y: str = "y"
22
28
  b: B = field(default_factory=B)
29
+ e: E = E.A
23
30
 
24
31
 
25
32
  @dataclass
@@ -38,12 +45,13 @@ def app(cfg: Config):
38
45
  hydraflow.set_experiment()
39
46
  rc = hydraflow.list_runs()
40
47
  log.info(rc)
48
+ log.info(cfg)
49
+ log.info(hydraflow.get_overrides())
41
50
  log.info(hydraflow.select_overrides(cfg))
42
51
  log.info(rc.filter(cfg, override=True))
43
- log.info(rc.filter(cfg, select=["x"]))
44
- log.info(rc.try_find_last(cfg, override=True))
45
- log.info(rc.try_find_last(cfg, select=["x"]))
46
- log.info(rc.filter(cfg))
52
+ for r in rc:
53
+ log.info(r.data.params)
54
+ log.info(hydraflow.load_config(r))
47
55
 
48
56
  cfg.y = 2 * cfg.x
49
57
  with hydraflow.start_run(cfg):
@@ -30,6 +30,16 @@ def app(cfg: MySQLConfig):
30
30
  with hydraflow.chdir_hydra_output() as path:
31
31
  Path("chdir_hydra.txt").write_text(path.as_posix())
32
32
 
33
+ o = hydraflow.select_overrides(cfg)
34
+ if "host" in o:
35
+ assert o["host"] == cfg.host
36
+
37
+ if "port" not in o:
38
+ assert cfg.port == 3306
39
+
40
+ if "values" not in o:
41
+ assert cfg.get("values") == [1, 2, 3] # type: ignore
42
+
33
43
  hydraflow.set_experiment(prefix="_", suffix="_")
34
44
  with hydraflow.start_run(cfg) as run:
35
45
  log.info(f"START, {cfg.host}, {cfg.port} ")
@@ -52,6 +62,9 @@ def app(cfg: MySQLConfig):
52
62
 
53
63
  assert hydraflow.get_overrides() == hydraflow.load_overrides(run)
54
64
 
65
+ if cfg.host == "error":
66
+ raise Exception("error")
67
+
55
68
  log.info("END")
56
69
 
57
70
 
@@ -179,3 +179,37 @@ def test_app_filter_list(rc: RunCollection):
179
179
  assert len(filtered) == 4
180
180
  filtered = rc.filter(values=[1])
181
181
  assert not filtered
182
+
183
+
184
+ def test_values(rc: RunCollection):
185
+ values = rc.values("host")
186
+ assert values == ["x", "x", "y", "y"]
187
+ values = rc.values(["host", "port"])
188
+ assert values == [("x", 1), ("x", 2), ("y", 1), ("y", 2)]
189
+
190
+
191
+ def test_sort_by(rc: RunCollection):
192
+ sorted = rc.sort_by("host", reverse=True)
193
+ assert sorted.values(["host", "port"]) == [("y", 1), ("y", 2), ("x", 1), ("x", 2)]
194
+
195
+ sorted = rc.sort_by(["host", "port"], reverse=True)
196
+ assert sorted.values(["host", "port"]) == [("y", 2), ("y", 1), ("x", 2), ("x", 1)]
197
+
198
+
199
+ def test_log_run_error(monkeypatch, tmp_path):
200
+ file = Path("tests/scripts/app.py").absolute()
201
+ monkeypatch.chdir(tmp_path)
202
+
203
+ args = [sys.executable, file.as_posix()]
204
+ args += ["host=error", "hydra.job.name=error"]
205
+ cp = subprocess.run(args, check=False, capture_output=True)
206
+ assert cp.returncode == 1
207
+ assert b"Error during log_run: error" in cp.stdout
208
+
209
+
210
+ def test_chdir_artifact(rc: RunCollection):
211
+ from hydraflow.context import chdir_artifact
212
+
213
+ with chdir_artifact(rc[0]):
214
+ assert Path.cwd().stem == "artifacts"
215
+ assert Path.cwd().parent.stem == rc[0].info.run_id
@@ -164,3 +164,58 @@ def test_run(tmp_path: Path):
164
164
  assert stderr_lines == ["world"]
165
165
  assert Path(path).read_text() == "hello world"
166
166
  assert len(changes_detected) >= 2
167
+
168
+
169
+ @pytest.mark.asyncio
170
+ async def test_execute_command_nonexistent():
171
+ from hydraflow.asyncio import execute_command
172
+
173
+ stop_event = asyncio.Event()
174
+
175
+ rc = await execute_command("nonexistent_command", stop_event=stop_event)
176
+ assert rc == 1
177
+ assert stop_event.is_set()
178
+
179
+
180
+ @pytest.mark.asyncio
181
+ async def test_process_stream_none():
182
+ from hydraflow.asyncio import process_stream
183
+
184
+ assert await process_stream(None, None) is None
185
+
186
+
187
+ @pytest.mark.asyncio
188
+ async def test_monitor_file_changes_error():
189
+ from hydraflow.asyncio import monitor_file_changes
190
+
191
+ stop_event = asyncio.Event()
192
+
193
+ with pytest.raises(FileNotFoundError):
194
+ await monitor_file_changes(["nonexistent_path"], lambda _: None, stop_event)
195
+
196
+
197
+ @pytest.mark.asyncio
198
+ async def test_run_and_monitor_none():
199
+ from hydraflow.asyncio import run_and_monitor
200
+
201
+ assert await run_and_monitor("echo", "hello") == 0
202
+
203
+
204
+ @pytest.mark.asyncio
205
+ async def test_run_and_monitor_error():
206
+ from hydraflow.asyncio import run_and_monitor
207
+
208
+ with pytest.raises(FileNotFoundError):
209
+ await run_and_monitor(
210
+ "nonexistent_command",
211
+ watch=lambda _: None,
212
+ paths=["nonexistent_path"],
213
+ )
214
+
215
+
216
+ def test_run_cwd():
217
+ from hydraflow.asyncio import run
218
+
219
+ return_code = run(sys.executable, "--version", watch=lambda _: None)
220
+
221
+ assert return_code == 0
@@ -49,6 +49,22 @@ def test_param(param, x, y):
49
49
  assert match(p, x)
50
50
 
51
51
 
52
+ def test_param_float():
53
+ from hydraflow.param import match
54
+
55
+ assert match("1.0", 1.0)
56
+ assert match("1.0", 1)
57
+ assert match("0.0", 0)
58
+ assert match("0.0", 0.0)
59
+
60
+
61
+ def test_param_bool():
62
+ from hydraflow.param import match
63
+
64
+ assert not match("1", True)
65
+ assert not match("0", False)
66
+
67
+
52
68
  def test_match_list():
53
69
  from hydraflow.param import _match_list
54
70
 
@@ -76,3 +92,26 @@ def test_match_tuple():
76
92
  assert _match_tuple("1", (True, False)) is None
77
93
  assert _match_tuple("1", (None, None)) is None
78
94
  assert _match_tuple("1", (1, 3.2)) is None
95
+
96
+
97
+ def test_to_value():
98
+ from hydraflow.param import to_value
99
+
100
+ assert to_value("1", int) == 1
101
+ assert to_value("1", float) == 1.0
102
+ assert to_value("1.0", float) == 1.0
103
+ assert to_value("True", bool) is True
104
+ assert to_value("False", bool) is False
105
+ assert to_value("None", int) is None
106
+ assert to_value("a", str) == "a"
107
+
108
+
109
+ def test_to_value_list():
110
+ from hydraflow.param import to_value
111
+
112
+ x = "[1, 2, 3]"
113
+ assert to_value(x, list) == [1, 2, 3]
114
+ x = "[1.2, 2.3, 3.4]"
115
+ assert to_value(x, list) == [1.2, 2.3, 3.4]
116
+ x = "[a, b, c]"
117
+ assert to_value(x, list) == ["a", "b", "c"]
@@ -126,7 +126,7 @@ def test_filter_status_enum(run_list: list[Run]):
126
126
 
127
127
 
128
128
  def test_get_params(run_list: list[Run]):
129
- from hydraflow.run_collection import get_params
129
+ from hydraflow.param import get_params
130
130
 
131
131
  assert get_params(run_list[1], "p") == ("1",)
132
132
  assert get_params(run_list[2], "p", "q") == ("2", "0")
@@ -135,6 +135,12 @@ def test_get_params(run_list: list[Run]):
135
135
  assert get_params(run_list[5], ["a", "q"], "r") == (None, "None", "2")
136
136
 
137
137
 
138
+ def test_get_values(run_list: list[Run]):
139
+ from hydraflow.param import get_values
140
+
141
+ assert get_values(run_list[3], ["p", "q"], [int, int]) == (3, 0)
142
+
143
+
138
144
  @pytest.mark.parametrize("i", range(6))
139
145
  def test_chdir_artifact_list(i: int, run_list: list[Run]):
140
146
  from hydraflow.context import chdir_artifact
@@ -1,17 +0,0 @@
1
- {
2
- "image": "mcr.microsoft.com/vscode/devcontainers/python:3.12",
3
- "features": {
4
- "ghcr.io/devcontainers-contrib/features/starship:1": {},
5
- "ghcr.io/va-h/devcontainers-features/uv:1": {}
6
- },
7
- "customizations": {
8
- "vscode": {
9
- "extensions": [
10
- "charliermarsh.ruff",
11
- "ms-python.python",
12
- "ms-python.vscode-pylance"
13
- ]
14
- }
15
- },
16
- "postCreateCommand": ".devcontainer/postCreate.sh"
17
- }
@@ -1,5 +0,0 @@
1
- #!/bin/sh
2
-
3
- echo 'eval "$(starship init bash)"' >> ~/.bashrc
4
- mkdir -p ~/.config
5
- cp .devcontainer/starship.toml ~/.config
@@ -1,78 +0,0 @@
1
- """Provide utility functions for parameter matching.
2
-
3
- The main function `match` checks if a given parameter matches a specified value.
4
- It supports various types of values including None, boolean, list, tuple, int,
5
- float, and str.
6
-
7
- Helper functions `_match_list` and `_match_tuple` are used internally to handle
8
- matching for list and tuple types respectively.
9
- """
10
-
11
- from __future__ import annotations
12
-
13
- from typing import Any
14
-
15
-
16
- def match(param: str, value: Any) -> bool:
17
- """Check if the string matches the specified value.
18
-
19
- Args:
20
- param (str): The parameter to check.
21
- value (Any): The value to check.
22
-
23
- Returns:
24
- True if the parameter matches the specified value,
25
- False otherwise.
26
-
27
- """
28
- if value in [None, True, False]:
29
- return param == str(value)
30
-
31
- if isinstance(value, list) and (m := _match_list(param, value)) is not None:
32
- return m
33
-
34
- if isinstance(value, tuple) and (m := _match_tuple(param, value)) is not None:
35
- return m
36
-
37
- if isinstance(value, str):
38
- return param == value
39
-
40
- if isinstance(value, int | float):
41
- return type(value)(param) == value
42
-
43
- return param == str(value)
44
-
45
-
46
- def _match_list(param: str, value: list) -> bool | None:
47
- if not value:
48
- return None
49
-
50
- if any(param.startswith(x) for x in ["[", "(", "{"]):
51
- return None
52
-
53
- if isinstance(value[0], bool):
54
- return None
55
-
56
- if not isinstance(value[0], int | float | str):
57
- return None
58
-
59
- return type(value[0])(param) in value
60
-
61
-
62
- def _match_tuple(param: str, value: tuple) -> bool | None:
63
- if len(value) != 2: # noqa: PLR2004
64
- return None
65
-
66
- if any(param.startswith(x) for x in ["[", "(", "{"]):
67
- return None
68
-
69
- if isinstance(value[0], bool):
70
- return None
71
-
72
- if not isinstance(value[0], int | float | str):
73
- return None
74
-
75
- if type(value[0]) is not type(value[1]):
76
- return None
77
-
78
- return value[0] <= type(value[0])(param) <= value[1] # type: ignore
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes