hydraflow 0.3.3__tar.gz → 0.3.4__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (41) hide show
  1. {hydraflow-0.3.3 → hydraflow-0.3.4}/PKG-INFO +1 -1
  2. {hydraflow-0.3.3 → hydraflow-0.3.4}/pyproject.toml +1 -1
  3. {hydraflow-0.3.3 → hydraflow-0.3.4}/src/hydraflow/__init__.py +2 -2
  4. {hydraflow-0.3.3 → hydraflow-0.3.4}/src/hydraflow/config.py +13 -0
  5. {hydraflow-0.3.3 → hydraflow-0.3.4}/src/hydraflow/context.py +1 -1
  6. {hydraflow-0.3.3 → hydraflow-0.3.4}/src/hydraflow/param.py +4 -1
  7. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/scripts/app.py +1 -1
  8. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_config.py +21 -0
  9. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_run_collection.py +17 -0
  10. {hydraflow-0.3.3 → hydraflow-0.3.4}/.devcontainer/devcontainer.json +0 -0
  11. {hydraflow-0.3.3 → hydraflow-0.3.4}/.devcontainer/postCreate.sh +0 -0
  12. {hydraflow-0.3.3 → hydraflow-0.3.4}/.devcontainer/starship.toml +0 -0
  13. {hydraflow-0.3.3 → hydraflow-0.3.4}/.gitattributes +0 -0
  14. {hydraflow-0.3.3 → hydraflow-0.3.4}/.gitignore +0 -0
  15. {hydraflow-0.3.3 → hydraflow-0.3.4}/LICENSE +0 -0
  16. {hydraflow-0.3.3 → hydraflow-0.3.4}/README.md +0 -0
  17. {hydraflow-0.3.3 → hydraflow-0.3.4}/apps/quickstart.py +0 -0
  18. {hydraflow-0.3.3 → hydraflow-0.3.4}/mkdocs.yml +0 -0
  19. {hydraflow-0.3.3 → hydraflow-0.3.4}/src/hydraflow/asyncio.py +0 -0
  20. {hydraflow-0.3.3 → hydraflow-0.3.4}/src/hydraflow/mlflow.py +0 -0
  21. {hydraflow-0.3.3 → hydraflow-0.3.4}/src/hydraflow/progress.py +0 -0
  22. {hydraflow-0.3.3 → hydraflow-0.3.4}/src/hydraflow/py.typed +0 -0
  23. {hydraflow-0.3.3 → hydraflow-0.3.4}/src/hydraflow/run_collection.py +0 -0
  24. {hydraflow-0.3.3 → hydraflow-0.3.4}/src/hydraflow/run_data.py +0 -0
  25. {hydraflow-0.3.3 → hydraflow-0.3.4}/src/hydraflow/run_info.py +0 -0
  26. {hydraflow-0.3.3 → hydraflow-0.3.4}/src/hydraflow/utils.py +5 -5
  27. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/__init__.py +0 -0
  28. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/scripts/__init__.py +0 -0
  29. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/scripts/progress.py +0 -0
  30. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/scripts/watch.py +0 -0
  31. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_app.py +0 -0
  32. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_asyncio.py +0 -0
  33. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_context.py +0 -0
  34. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_log_run.py +0 -0
  35. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_mlflow.py +0 -0
  36. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_param.py +0 -0
  37. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_progress.py +0 -0
  38. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_run_data.py +0 -0
  39. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_run_info.py +0 -0
  40. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_version.py +0 -0
  41. {hydraflow-0.3.3 → hydraflow-0.3.4}/tests/test_watch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.3.3
3
+ Version: 0.3.4
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.3.3"
7
+ version = "0.3.4"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -1,6 +1,6 @@
1
1
  """Integrate Hydra and MLflow to manage and track machine learning experiments."""
2
2
 
3
- from .context import chdir_artifact, chdir_hydra, log_run, start_run, watch
3
+ from .context import chdir_artifact, chdir_hydra_output, log_run, start_run, watch
4
4
  from .mlflow import list_runs, search_runs, set_experiment
5
5
  from .progress import multi_tasks_progress, parallel_progress
6
6
  from .run_collection import RunCollection
@@ -15,7 +15,7 @@ from .utils import (
15
15
  __all__ = [
16
16
  "RunCollection",
17
17
  "chdir_artifact",
18
- "chdir_hydra",
18
+ "chdir_hydra_output",
19
19
  "get_artifact_dir",
20
20
  "get_hydra_output_dir",
21
21
  "get_overrides",
@@ -44,12 +44,25 @@ def iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
44
44
  if config is None:
45
45
  return
46
46
 
47
+ if isinstance(config, list) and all(isinstance(x, str) for x in config):
48
+ config = _from_dotlist(config)
49
+
47
50
  if not isinstance(config, DictConfig | ListConfig):
48
51
  config = OmegaConf.create(config) # type: ignore
49
52
 
50
53
  yield from _iter_params(config, prefix)
51
54
 
52
55
 
56
+ def _from_dotlist(config: list[str]) -> dict[str, str]:
57
+ result = {}
58
+ for item in config:
59
+ if "=" in item:
60
+ key, value = item.split("=", 1)
61
+ result[key.strip()] = value.strip()
62
+
63
+ return result
64
+
65
+
53
66
  def _iter_params(config: object, prefix: str = "") -> Iterator[tuple[str, Any]]:
54
67
  if isinstance(config, DictConfig):
55
68
  for key, value in config.items():
@@ -239,7 +239,7 @@ class Handler(PatternMatchingEventHandler):
239
239
 
240
240
 
241
241
  @contextmanager
242
- def chdir_hydra() -> Iterator[Path]:
242
+ def chdir_hydra_output() -> Iterator[Path]:
243
243
  """Change the current working directory to the hydra output directory.
244
244
 
245
245
  This context manager changes the current working directory to the hydra output
@@ -34,7 +34,10 @@ def match(param: str, value: Any) -> bool:
34
34
  if isinstance(value, tuple) and (m := _match_tuple(param, value)) is not None:
35
35
  return m
36
36
 
37
- if isinstance(value, int | float | str):
37
+ if isinstance(value, str):
38
+ return param == value
39
+
40
+ if isinstance(value, int | float):
38
41
  return type(value)(param) == value
39
42
 
40
43
  return param == str(value)
@@ -27,7 +27,7 @@ cs.store(name="config", node=MySQLConfig)
27
27
 
28
28
  @hydra.main(version_base=None, config_name="config")
29
29
  def app(cfg: MySQLConfig):
30
- with hydraflow.chdir_hydra() as path:
30
+ with hydraflow.chdir_hydra_output() as path:
31
31
  Path("chdir_hydra.txt").write_text(path.as_posix())
32
32
 
33
33
  hydraflow.set_experiment(prefix="_", suffix="_")
@@ -205,3 +205,24 @@ def test_list_config_str(s):
205
205
  assert isinstance(b, ListConfig)
206
206
  t = OmegaConf.create(json.loads(s))
207
207
  assert b == t
208
+
209
+
210
+ @pytest.mark.parametrize("x", [{"a": 1}, {"a": [1, 2, 3]}])
211
+ def test_collect_params_dict(x):
212
+ from hydraflow.config import collect_params
213
+
214
+ assert collect_params(x) == x
215
+
216
+
217
+ def test_collect_params_dict_dot():
218
+ from hydraflow.config import collect_params
219
+
220
+ assert collect_params({"a": {"b": 1}}) == {"a.b": 1}
221
+ assert collect_params({"a.b": 1}) == {"a.b": 1}
222
+
223
+
224
+ def test_collect_params_list_dot():
225
+ from hydraflow.config import collect_params
226
+
227
+ assert collect_params(["a=1"]) == {"a": "1"}
228
+ assert collect_params(["a.b=2", "c"]) == {"a.b": "2"}
@@ -67,6 +67,8 @@ def test_filter_one(run_list: list[Run]):
67
67
  assert len(x) == 1
68
68
  x = filter_runs(run_list, p=1)
69
69
  assert len(x) == 1
70
+ x = filter_runs(run_list, ["p=1"])
71
+ assert len(x) == 1
70
72
 
71
73
 
72
74
  def test_filter_all(run_list: list[Run]):
@@ -77,6 +79,8 @@ def test_filter_all(run_list: list[Run]):
77
79
  assert len(x) == 5
78
80
  x = filter_runs(run_list, q=0)
79
81
  assert len(x) == 5
82
+ x = filter_runs(run_list, ["q=0"])
83
+ assert len(x) == 5
80
84
 
81
85
 
82
86
  def test_filter_list(run_list: list[Run]):
@@ -98,6 +102,8 @@ def test_filter_invalid_param(run_list: list[Run]):
98
102
 
99
103
  x = filter_runs(run_list, {"invalid": 0})
100
104
  assert len(x) == 6
105
+ x = filter_runs(run_list, ["invalid=0"])
106
+ assert len(x) == 6
101
107
 
102
108
 
103
109
  def test_filter_status(run_list: list[Run]):
@@ -181,15 +187,20 @@ def test_filter(rc: RunCollection):
181
187
  assert len(rc.filter()) == 6
182
188
  assert len(rc.filter({})) == 6
183
189
  assert len(rc.filter({"p": 1})) == 1
190
+ assert len(rc.filter(["p=1"])) == 1
184
191
  assert len(rc.filter({"q": 0})) == 5
192
+ assert len(rc.filter(["q=0"])) == 5
185
193
  assert len(rc.filter({"q": -1})) == 0
194
+ assert len(rc.filter(["q=-1"])) == 0
186
195
  assert not rc.filter({"q": -1})
187
196
  assert len(rc.filter(p=5)) == 1
188
197
  assert len(rc.filter(q=0)) == 5
189
198
  assert len(rc.filter(q=-1)) == 0
190
199
  assert not rc.filter(q=-1)
191
200
  assert len(rc.filter({"r": 2})) == 2
201
+ assert len(rc.filter(["r=2"])) == 2
192
202
  assert len(rc.filter(r=0)) == 2
203
+ assert len(rc.filter(["r=0"])) == 2
193
204
 
194
205
 
195
206
  def test_get(rc: RunCollection):
@@ -197,15 +208,21 @@ def test_get(rc: RunCollection):
197
208
  assert isinstance(run, Run)
198
209
  run = rc.get(p=2)
199
210
  assert isinstance(run, Run)
211
+ run = rc.get(["p=3"])
212
+ assert isinstance(run, Run)
200
213
 
201
214
 
202
215
  def test_try_get(rc: RunCollection):
203
216
  run = rc.try_get({"p": 5})
204
217
  assert isinstance(run, Run)
218
+ run = rc.try_get(["p=2"])
219
+ assert isinstance(run, Run)
205
220
  run = rc.try_get(p=1)
206
221
  assert isinstance(run, Run)
207
222
  run = rc.try_get(p=-1)
208
223
  assert run is None
224
+ run = rc.try_get(["p=-2"])
225
+ assert run is None
209
226
 
210
227
 
211
228
  def test_get_param_names(rc: RunCollection):
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
@@ -68,11 +68,6 @@ def get_hydra_output_dir(run: Run | None = None) -> Path:
68
68
  raise FileNotFoundError
69
69
 
70
70
 
71
- def get_overrides() -> list[str]:
72
- """Retrieve the overrides for the current run."""
73
- return HydraConfig.get().overrides.task
74
-
75
-
76
71
  def load_config(run: Run) -> DictConfig:
77
72
  """Load the configuration for a given run.
78
73
 
@@ -93,6 +88,11 @@ def load_config(run: Run) -> DictConfig:
93
88
  return OmegaConf.load(path) # type: ignore
94
89
 
95
90
 
91
+ def get_overrides() -> list[str]:
92
+ """Retrieve the overrides for the current run."""
93
+ return HydraConfig.get().overrides.task
94
+
95
+
96
96
  def load_overrides(run: Run) -> list[str]:
97
97
  """Load the overrides for a given run.
98
98
 
File without changes
File without changes
File without changes
File without changes