hydraflow 0.2.5__py3-none-any.whl → 0.2.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hydraflow/progress.py ADDED
@@ -0,0 +1,56 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import joblib
6
+ from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
7
+
8
+ if TYPE_CHECKING:
9
+ from collections.abc import Iterable
10
+
11
+
12
def progress(
    *iterables: Iterable[int | tuple[int, int]],
    n_jobs: int = -1,
    task_name: str = "#{:0>3}",
    main_task_name: str = "main",
) -> None:
    """
    Consume iterables while displaying rich progress bars.

    One progress task is created per iterable. Each iterable is consumed to
    exhaustion, and every item it yields updates the progress of its task:

    - an ``int`` ``index`` marks ``index + 1`` steps as completed, leaving
      the task's total as last reported (initially unknown);
    - a ``tuple`` ``(index, total)`` marks ``index + 1`` of ``total`` steps
      as completed.

    When more than one iterable is given, they are consumed on a joblib
    thread pool and an extra main task aggregates the completed counts of
    all of them. The main task's total is only reported once every iterable
    has reported its own total.

    Args:
        *iterables: The iterables to consume. Passing none is a no-op.
        n_jobs: Number of jobs passed to ``joblib.Parallel``. Only used when
            more than one iterable is given.
        task_name: Format string for each task's description, formatted with
            the task's zero-based position.
        main_task_name: Description of the aggregate main task.
    """
    n = len(iterables)
    if n == 0:
        # Nothing to track; without this guard func(0) would index an empty
        # ``tasks`` list and raise IndexError.
        return

    with Progress(
        SpinnerColumn(),
        *Progress.get_default_columns(),
        TimeElapsedColumn(),
    ) as progress_bar:
        task_main = progress_bar.add_task(main_task_name, total=None) if n > 1 else None
        tasks = [
            progress_bar.add_task(task_name.format(i), start=False, total=None)
            for i in range(n)
        ]

        # Pre-populate one slot per iterable *before* any worker thread starts.
        # Workers then only reassign existing keys, so iterating these dicts
        # from another thread never sees them change size. It also keeps the
        # main total unknown until every iterable (started or not) has
        # reported its own total.
        total: dict[int, int | None] = dict.fromkeys(range(n))
        completed: dict[int, int] = dict.fromkeys(range(n), 0)

        def func(i: int) -> None:
            # Consume iterables[i], updating its task and the main task.
            progress_bar.start_task(tasks[i])

            for index in iterables[i]:
                if isinstance(index, tuple):
                    completed[i], total[i] = index[0] + 1, index[1]
                else:
                    completed[i] = index + 1

                progress_bar.update(tasks[i], total=total[i], completed=completed[i])

                if task_main is not None:
                    totals = list(total.values())
                    t = sum(totals) if all(x is not None for x in totals) else None
                    c = sum(completed.values())
                    progress_bar.update(task_main, total=t, completed=c)

        if n > 1:
            jobs = (joblib.delayed(func)(i) for i in range(n))
            joblib.Parallel(n_jobs, prefer="threads")(jobs)
        else:
            func(0)
hydraflow/runs.py CHANGED
@@ -51,13 +51,6 @@ def search_runs(
51
51
  error if ``experiment_names`` is also not ``None`` or ``[]``.
52
52
  ``None`` will default to the active experiment if ``experiment_names``
53
53
  is ``None`` or ``[]``.
54
- experiment_ids (list[str] | None): List of experiment IDs. Search can
55
- work with experiment IDs or experiment names, but not both in the
56
- same call. Values other than ``None`` or ``[]`` will result in
57
- error if ``experiment_names`` is also not ``None`` or ``[]``.
58
- ``experiment_names`` is also not ``None`` or ``[]``. ``None`` will
59
- default to the active experiment if ``experiment_names`` is ``None``
60
- or ``[]``.
61
54
  filter_string (str): Filter query string, defaults to searching all
62
55
  runs.
63
56
  run_view_type (int): one of enum values ``ACTIVE_ONLY``, ``DELETED_ONLY``,
@@ -501,30 +494,28 @@ class RunCollection:
501
494
  """
502
495
  return (func(download_artifacts(run_id=run.info.run_id)) for run in self._runs)
503
496
 
504
def group_by(self, *names: str | list[str]) -> dict[tuple[str | None, ...], RunCollection]:
    """
    Partition the runs in this collection by parameter values.

    Every run is keyed by the tuple of values that ``get_params`` returns
    for ``names``. Runs that share the same key are collected into the same
    group, keeping their original relative order.

    Args:
        *names (str | list[str]): Parameter names to group by, passed as
            separate arguments, as a list, or as a mix of both.

    Returns:
        dict[tuple[str | None, ...], RunCollection]: Mapping from each
        distinct key tuple to a RunCollection of the runs that produced it.
        Keys appear in the order in which they were first encountered.
    """
    buckets: dict[tuple[str | None, ...], list[Run]] = {}

    for run in self._runs:
        key = get_params(run, *names)
        if key in buckets:
            buckets[key].append(run)
        else:
            buckets[key] = [run]

    return {key: RunCollection(members) for key, members in buckets.items()}
530
521
 
@@ -792,11 +783,32 @@ def try_get_run(runs: list[Run], config: object | None = None, **kwargs) -> Run
792
783
  raise ValueError(msg)
793
784
 
794
785
 
795
def get_params(run: Run, *names: str | list[str] | tuple[str, ...]) -> tuple[str | None, ...]:
    """
    Retrieve the values of specified parameters from the given run.

    This function extracts the values of the parameters identified by the
    provided names from the specified run. Each argument may be a single
    parameter name or a list/tuple of names; sequences are flattened in
    place, so ``get_params(run, "a", ["b", "c"])`` looks up ``a``, ``b``
    and ``c`` in that order.

    Args:
        run (Run): The run object from which to extract parameter values.
        *names (str | list[str] | tuple[str, ...]): The names of the
            parameters to retrieve, given as separate arguments, as lists or
            tuples of names, or any mix of these.

    Returns:
        tuple[str | None, ...]: A tuple containing the values of the specified
        parameters in the order they were provided. A name that is not
        present in the run's params yields ``None`` at its position.
    """
    params = run.data.params

    flat_names: list[str] = []
    for name in names:
        # Tuples are flattened like lists; previously a tuple was used as a
        # single (never-matching) key and always produced ``None``.
        if isinstance(name, (list, tuple)):
            flat_names.extend(name)
        else:
            flat_names.append(name)

    return tuple(params.get(name) for name in flat_names)
800
812
 
801
813
 
802
814
  def get_param_names(runs: list[Run]) -> list[str]:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: hydraflow
3
- Version: 0.2.5
3
+ Version: 0.2.6
4
4
  Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
5
5
  Project-URL: Documentation, https://github.com/daizutabi/hydraflow
6
6
  Project-URL: Source, https://github.com/daizutabi/hydraflow
@@ -17,7 +17,9 @@ Classifier: Topic :: Documentation
17
17
  Classifier: Topic :: Software Development :: Documentation
18
18
  Requires-Python: >=3.10
19
19
  Requires-Dist: hydra-core>1.3
20
+ Requires-Dist: joblib
20
21
  Requires-Dist: mlflow>2.15
22
+ Requires-Dist: rich
21
23
  Requires-Dist: setuptools
22
24
  Requires-Dist: watchdog
23
25
  Requires-Dist: watchfiles
@@ -3,8 +3,9 @@ hydraflow/asyncio.py,sha256=yh851L315QHzRBwq6r-uwO2oZKgz1JawHp-fswfxT1E,6175
3
3
  hydraflow/config.py,sha256=6TCKNQZ3sSrIEvl245T2udwFuknejyN1dMcIVmOHdrQ,2102
4
4
  hydraflow/context.py,sha256=8Qn99yCSkCarDDthQ6hjgW80CBBIg0H7fnLvtw4ZXo8,7248
5
5
  hydraflow/mlflow.py,sha256=gGr0fvFEllduA-ByHMeEamM39zVY_30tjtEbkSZ4lHA,3659
6
- hydraflow/runs.py,sha256=41P2aIm7Alem3uKHd-JJdoDzzA4LwrO0rIIZKqZGmdc,31071
7
- hydraflow-0.2.5.dist-info/METADATA,sha256=KDDgZxTmODbd9fSiwLrURTk7il53CQzGkpGrAshPp1s,4139
8
- hydraflow-0.2.5.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
- hydraflow-0.2.5.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
10
- hydraflow-0.2.5.dist-info/RECORD,,
6
+ hydraflow/progress.py,sha256=dReFp-AfBuYpjGQnqRmkwPcoyFfe2WCgkklXuo9ZjNg,1709
7
+ hydraflow/runs.py,sha256=TETX54OVJPJLi6rjpNcsXAhXH2Q9unhjXhGkOtFtHng,31559
8
+ hydraflow-0.2.6.dist-info/METADATA,sha256=yOEx7M9jM5M7MNkLOZShO-DexNqXzIHjSkqbxcNMHQ0,4181
9
+ hydraflow-0.2.6.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
10
+ hydraflow-0.2.6.dist-info/licenses/LICENSE,sha256=IGdDrBPqz1O0v_UwCW-NJlbX9Hy9b3uJ11t28y2srmY,1062
11
+ hydraflow-0.2.6.dist-info/RECORD,,