hydraflow 0.2.5__tar.gz → 0.2.6__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (29)
  1. {hydraflow-0.2.5 → hydraflow-0.2.6}/PKG-INFO +3 -1
  2. {hydraflow-0.2.5 → hydraflow-0.2.6}/pyproject.toml +3 -1
  3. hydraflow-0.2.6/src/hydraflow/progress.py +56 -0
  4. {hydraflow-0.2.5 → hydraflow-0.2.6}/src/hydraflow/runs.py +40 -28
  5. hydraflow-0.2.6/tests/scripts/progress.py +22 -0
  6. {hydraflow-0.2.5 → hydraflow-0.2.6}/tests/test_asyncio.py +1 -0
  7. {hydraflow-0.2.5 → hydraflow-0.2.6}/tests/test_context.py +20 -12
  8. hydraflow-0.2.6/tests/test_progress.py +0 -0
  9. {hydraflow-0.2.5 → hydraflow-0.2.6}/tests/test_runs.py +54 -6
  10. {hydraflow-0.2.5 → hydraflow-0.2.6}/tests/test_watch.py +4 -2
  11. {hydraflow-0.2.5 → hydraflow-0.2.6}/.devcontainer/devcontainer.json +0 -0
  12. {hydraflow-0.2.5 → hydraflow-0.2.6}/.devcontainer/postCreate.sh +0 -0
  13. {hydraflow-0.2.5 → hydraflow-0.2.6}/.devcontainer/starship.toml +0 -0
  14. {hydraflow-0.2.5 → hydraflow-0.2.6}/.gitattributes +0 -0
  15. {hydraflow-0.2.5 → hydraflow-0.2.6}/.gitignore +0 -0
  16. {hydraflow-0.2.5 → hydraflow-0.2.6}/LICENSE +0 -0
  17. {hydraflow-0.2.5 → hydraflow-0.2.6}/README.md +0 -0
  18. {hydraflow-0.2.5 → hydraflow-0.2.6}/src/hydraflow/__init__.py +0 -0
  19. {hydraflow-0.2.5 → hydraflow-0.2.6}/src/hydraflow/asyncio.py +0 -0
  20. {hydraflow-0.2.5 → hydraflow-0.2.6}/src/hydraflow/config.py +0 -0
  21. {hydraflow-0.2.5 → hydraflow-0.2.6}/src/hydraflow/context.py +0 -0
  22. {hydraflow-0.2.5 → hydraflow-0.2.6}/src/hydraflow/mlflow.py +0 -0
  23. {hydraflow-0.2.5 → hydraflow-0.2.6}/tests/scripts/__init__.py +0 -0
  24. {hydraflow-0.2.5 → hydraflow-0.2.6}/tests/scripts/log_run.py +0 -0
  25. {hydraflow-0.2.5 → hydraflow-0.2.6}/tests/scripts/watch.py +0 -0
  26. {hydraflow-0.2.5 → hydraflow-0.2.6}/tests/test_config.py +0 -0
  27. {hydraflow-0.2.5 → hydraflow-0.2.6}/tests/test_log_run.py +0 -0
  28. {hydraflow-0.2.5 → hydraflow-0.2.6}/tests/test_mlflow.py +0 -0
  29. {hydraflow-0.2.5 → hydraflow-0.2.6}/tests/test_version.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -17,7 +17,9 @@ Classifier: Topic :: Documentation
17
17
  Classifier: Topic :: Software Development :: Documentation
18
18
  Requires-Python: >=3.10
19
19
  Requires-Dist: hydra-core>1.3
20
+ Requires-Dist: joblib
20
21
  Requires-Dist: mlflow>2.15
22
+ Requires-Dist: rich
21
23
  Requires-Dist: setuptools
22
24
  Requires-Dist: watchdog
23
25
  Requires-Dist: watchfiles
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "hydraflow"
7
- version = "0.2.5"
7
+ version = "0.2.6"
8
8
  description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
9
9
  readme = "README.md"
10
10
  license = "MIT"
@@ -21,7 +21,9 @@ classifiers = [
21
21
  requires-python = ">=3.10"
22
22
  dependencies = [
23
23
  "hydra-core>1.3",
24
+ "joblib",
24
25
  "mlflow>2.15",
26
+ "rich",
25
27
  "setuptools",
26
28
  "watchdog",
27
29
  "watchfiles",
@@ -0,0 +1,56 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import joblib
6
+ from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
7
+
8
+ if TYPE_CHECKING:
9
+ from collections.abc import Iterable
10
+
11
+
12
def progress(
    *iterables: Iterable[int | tuple[int, int]],
    n_jobs: int = -1,
    task_name: str = "#{:0>3}",
    main_task_name: str = "main",
) -> None:
    """Show Rich progress bars while consuming one or more iterables.

    Each iterable gets its own bar. Every value it yields advances that bar:
    an ``int`` index ``k`` marks ``k + 1`` steps completed (total unknown),
    and a ``(k, total)`` tuple marks ``k + 1`` of ``total`` steps. With more
    than one iterable, they are consumed concurrently by ``joblib.Parallel``
    using threads, and an extra main bar shows the sums over all bars that
    have started so far.

    Args:
        *iterables: Iterables yielding ``int`` or ``(int, int)`` progress
            values; they are consumed only to drive the bars.
        n_jobs: Number of parallel jobs passed to ``joblib.Parallel`` when
            more than one iterable is given.
        task_name: ``str.format`` template for each bar's label; it receives
            the iterable's position.
        main_task_name: Label of the aggregate bar, shown only when more than
            one iterable is given.
    """
    import threading

    n = len(iterables)
    if n == 0:
        # Nothing to display; without this guard ``func(0)`` below would
        # raise IndexError on ``tasks[0]``.
        return

    with Progress(
        SpinnerColumn(),
        *Progress.get_default_columns(),
        TimeElapsedColumn(),
    ) as bar:  # not named ``progress`` so it does not shadow this function
        task_main = bar.add_task(main_task_name, total=None) if n > 1 else None
        tasks = [bar.add_task(task_name.format(i), start=False, total=None) for i in range(n)]

        total: dict[int, int | None] = {}
        completed: dict[int, int] = {}
        # Worker threads insert keys into these dicts while other threads sum
        # over ``.values()``; unsynchronized, that can raise "dictionary
        # changed size during iteration". Holding the lock across the
        # main-bar update also keeps those updates from being applied out of
        # order.
        lock = threading.Lock()

        def func(i: int) -> None:
            with lock:
                completed[i] = 0
                total[i] = None
            bar.start_task(tasks[i])

            for index in iterables[i]:
                with lock:
                    if isinstance(index, tuple):
                        completed[i], total[i] = index[0] + 1, index[1]
                    else:
                        completed[i] = index + 1

                    bar.update(tasks[i], total=total[i], completed=completed[i])
                    if task_main is not None:
                        # The aggregate total is known only once every started
                        # bar has reported its own total.
                        if all(t is not None for t in total.values()):
                            agg_total = sum(total.values())
                        else:
                            agg_total = None
                        agg_completed = sum(completed.values())
                        bar.update(task_main, total=agg_total, completed=agg_completed)

        if n > 1:
            it = (joblib.delayed(func)(i) for i in range(n))
            joblib.Parallel(n_jobs, prefer="threads")(it)
        else:
            func(0)
@@ -51,13 +51,6 @@ def search_runs(
51
51
  error if ``experiment_names`` is also not ``None`` or ``[]``.
52
52
  ``None`` will default to the active experiment if ``experiment_names``
53
53
  is ``None`` or ``[]``.
54
- experiment_ids (list[str] | None): List of experiment IDs. Search can
55
- work with experiment IDs or experiment names, but not both in the
56
- same call. Values other than ``None`` or ``[]`` will result in
57
- error if ``experiment_names`` is also not ``None`` or ``[]``.
58
- ``experiment_names`` is also not ``None`` or ``[]``. ``None`` will
59
- default to the active experiment if ``experiment_names`` is ``None``
60
- or ``[]``.
61
54
  filter_string (str): Filter query string, defaults to searching all
62
55
  runs.
63
56
  run_view_type (int): one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``,
@@ -501,30 +494,28 @@ class RunCollection:
501
494
  """
502
495
  return (func(download_artifacts(run_id=run.info.run_id)) for run in self._runs)
503
496
 
504
def group_by(self, *names: str | list[str]) -> dict[tuple[str | None, ...], RunCollection]:
    """
    Group the runs of this collection by the values of the given parameters.

    Each run's key is the tuple returned by ``get_params`` for the given
    names. Runs that share a key end up in the same group, and groups appear
    in the order their key is first seen.

    Args:
        *names (str | list[str]): Parameter names to group by. Names may be
            passed as separate arguments, as lists, or as a mix of both.

    Returns:
        dict[tuple[str | None, ...], RunCollection]: Mapping from each tuple of
        parameter values to a RunCollection holding the runs with those
        values. A parameter missing from a run contributes ``None`` to its key.
    """
    buckets: dict[tuple[str | None, ...], list[Run]] = {}
    for run in self._runs:
        key = get_params(run, *names)
        if key in buckets:
            buckets[key].append(run)
        else:
            buckets[key] = [run]

    grouped = {}
    for key, members in buckets.items():
        grouped[key] = RunCollection(members)
    return grouped
530
521
 
@@ -792,11 +783,32 @@ def try_get_run(runs: list[Run], config: object | None = None, **kwargs) -> Run
792
783
  raise ValueError(msg)
793
784
 
794
785
 
795
def get_params(run: Run, *names: str | list[str]) -> tuple[str | None, ...]:
    """
    Look up the values of the named parameters on a run.

    Names can be given one per argument, as lists, or both. Lists are
    flattened in place, so the output order follows the order of the names.

    Args:
        run (Run): The run whose ``data.params`` mapping is read.
        *names (str | list[str]): Parameter names, or lists of names, to read.

    Returns:
        tuple[str | None, ...]: One value per flattened name, in order.
        ``None`` stands in for a name that is not among the run's params.
    """
    flat_names: list[str] = []
    for entry in names:
        flat_names.extend(entry if isinstance(entry, list) else [entry])

    return tuple(run.data.params.get(key) for key in flat_names)
800
812
 
801
813
 
802
814
  def get_param_names(runs: list[Run]) -> list[str]:
@@ -0,0 +1,22 @@
1
+ import random
2
+ import time
3
+
4
+ from hydraflow.progress import progress
5
+
6
+
7
def task(total):
    """Return a generator of ``(step, total)`` pairs with a random delay after each."""
    for step in range(total):
        yield step, total
        time.sleep(random.random())
14
+
15
+
16
def main():
    """Drive ``progress`` with twelve randomly sized demo tasks on four jobs."""
    sizes = [random.randint(10, 20) for _ in range(12)]
    tasks = [task(size) for size in sizes]
    progress(*tasks, n_jobs=4)


if __name__ == "__main__":
    main()
@@ -77,6 +77,7 @@ async def test_monitor_file_changes(tmp_path: Path, write_soon: Callable[[Path],
77
77
  await asyncio.sleep(1)
78
78
  stop_event.set()
79
79
  await monitor_task
80
+ await asyncio.sleep(1)
80
81
 
81
82
  assert len(changes_detected) > 0
82
83
 
@@ -1,3 +1,5 @@
1
+ import time
2
+ from pathlib import Path
1
3
  from unittest.mock import MagicMock, patch
2
4
 
3
5
  import mlflow
@@ -17,7 +19,7 @@ def runs(monkeypatch, tmp_path):
17
19
  patch("hydraflow.context.HydraConfig.get") as mock_hydra_config,
18
20
  patch("hydraflow.context.mlflow.log_artifacts") as mock_log_artifacts,
19
21
  ):
20
- mock_hydra_config.return_value.runtime.output_dir = "/tmp"
22
+ mock_hydra_config.return_value.runtime.output_dir = tmp_path.as_posix()
21
23
  mock_log_artifacts.return_value = None
22
24
 
23
25
  mlflow.set_experiment("test_run")
@@ -49,7 +51,7 @@ def test_runs_params_dict(runs: RunCollection, i: int):
49
51
  assert runs[i].data.params["d.i"] == str(i)
50
52
 
51
53
 
52
- def test_log_run_error_handling():
54
+ def test_log_run_error_handling(tmp_path: Path):
53
55
  config = MagicMock()
54
56
  config.some_param = "value"
55
57
 
@@ -59,7 +61,7 @@ def test_log_run_error_handling():
59
61
  patch("hydraflow.context.mlflow.log_artifacts") as mock_log_artifacts,
60
62
  ):
61
63
  mock_log_params.side_effect = Exception("Test exception")
62
- mock_hydra_config.return_value.runtime.output_dir = "/tmp"
64
+ mock_hydra_config.return_value.runtime.output_dir = tmp_path.as_posix()
63
65
  mock_log_artifacts.return_value = None
64
66
 
65
67
  with pytest.raises(Exception, match="Test exception"):
@@ -67,14 +69,20 @@ def test_log_run_error_handling():
67
69
  pass
68
70
 
69
71
 
70
def test_watch_context_manager(tmp_path: Path):
    """Writing a file inside ``watch`` invokes the callback once with its path."""
    test_dir = tmp_path / "test_watch"
    test_dir.mkdir(parents=True, exist_ok=True)
    test_file = test_dir / "test_file.txt"

    called = []

    def mock_func(path: Path):
        # Presumably invoked from the watcher's own thread, where a failing
        # assert would not fail the test; the paths are re-checked below.
        assert path == test_file
        called.append(path)

    with watch(mock_func, test_dir):
        test_file.write_text("new content")
        # Poll up to a deadline instead of a fixed sleep, so a slow
        # file-notification backend does not make the test flaky.
        deadline = time.monotonic() + 5
        while not called and time.monotonic() < deadline:
            time.sleep(0.05)
        # Short settle period so duplicate notifications for the same write
        # are still recorded (and caught by the length check) before the
        # watcher stops.
        time.sleep(0.2)

    assert len(called) == 1
    assert called[0] == test_file
File without changes
@@ -170,6 +170,16 @@ def test_try_get_run_error(run_list: list[Run]):
170
170
  try_get_run(run_list, {"q": 0})
171
171
 
172
172
 
173
def test_get_params(run_list: list[Run]):
    """``get_params`` accepts names as separate args, lists, or a mix."""
    from hydraflow.runs import get_params

    cases = [
        (1, ("p",), ("1",)),
        (2, ("p", "q"), ("2", "0")),
        (3, (["p", "q"],), ("3", "0")),
        (4, ("p", ["q", "r"]), ("4", "0", "1")),
        (5, (["a", "q"], "r"), (None, "None", "2")),
    ]
    for index, names, expected in cases:
        assert get_params(run_list[index], *names) == expected
181
+
182
+
173
183
  def test_get_param_names(run_list: list[Run]):
174
184
  from hydraflow.runs import get_param_names
175
185
 
@@ -427,15 +437,53 @@ def test_run_collection_group_by(runs: RunCollection):
427
437
  assert grouped[("0",)][0] == runs[0]
428
438
  assert grouped[("1",)][0] == runs[1]
429
439
 
430
- grouped = runs.group_by(["q"])
440
+ grouped = runs.group_by("q")
431
441
  assert len(grouped) == 2
432
442
 
433
- grouped = runs.group_by(["r"])
443
+ grouped = runs.group_by("r")
434
444
  assert len(grouped) == 3
435
445
 
436
446
 
437
- # def test_hydra_output_dir_error(runs_list: list[Run]):
438
- # from hydraflow.runs import get_hydra_output_dir
447
def test_filter_runs_empty_list():
    """Filtering an empty run list returns an empty list."""
    from hydraflow.runs import filter_runs

    assert filter_runs([], p=[0, 1, 2]) == []
452
+
453
+
454
def test_filter_runs_no_match(run_list: list[Run]):
    """Filtering on ``p`` values no run matches returns an empty list."""
    from hydraflow.runs import filter_runs

    assert filter_runs(run_list, p=[10, 11, 12]) == []
459
+
460
+
461
def test_get_run_no_match(run_list: list[Run]):
    """``get_run`` raises ``ValueError`` when no run matches the config."""
    from hydraflow.runs import get_run

    config = {"p": 10}
    with pytest.raises(ValueError):
        get_run(run_list, config)
466
+
439
467
 
440
- # with pytest.raises(FileNotFoundError):
441
- # get_hydra_output_dir(runs_list[0])
468
def test_get_run_multiple_params(run_list: list[Run]):
    """``get_run`` returns a Run whose params match several given values."""
    from hydraflow.runs import get_run

    found = get_run(run_list, {"p": 4, "q": 0})
    assert isinstance(found, Run)
    params = found.data.params
    assert params["p"] == "4"
    assert params["q"] == "0"
475
+
476
+
477
def test_try_get_run_no_match(run_list: list[Run]):
    """``try_get_run`` returns ``None`` when no run matches the config."""
    from hydraflow.runs import try_get_run

    result = try_get_run(run_list, {"p": 10})
    assert result is None
481
+
482
+
483
def test_try_get_run_multiple_params(run_list: list[Run]):
    """``try_get_run`` returns a Run whose params match several given values."""
    from hydraflow.runs import try_get_run

    found = try_get_run(run_list, {"p": 4, "q": 0})
    assert isinstance(found, Run)
    params = found.data.params
    assert params["p"] == "4"
    assert params["q"] == "0"
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import subprocess
4
+ import time
4
5
  from pathlib import Path
5
6
 
6
7
  import pytest
@@ -21,6 +22,7 @@ def test_watch(dir, monkeypatch, tmp_path):
21
22
 
22
23
  with watch(func, dir if isinstance(dir, str) else dir()):
23
24
  subprocess.check_call(["python", file])
25
+ time.sleep(1)
24
26
 
25
- assert results[0][0] == "watch.txt"
26
- assert results[0][1] == "watch"
27
+ assert results[0][0] == "watch.txt" # type: ignore
28
+ assert results[0][1] == "watch" # type: ignore
File without changes
File without changes
File without changes
File without changes