hydraflow 0.2.8__tar.gz → 0.2.10__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- {hydraflow-0.2.8 → hydraflow-0.2.10}/PKG-INFO +1 -1
- {hydraflow-0.2.8 → hydraflow-0.2.10}/pyproject.toml +2 -2
- {hydraflow-0.2.8 → hydraflow-0.2.10}/src/hydraflow/__init__.py +3 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/src/hydraflow/asyncio.py +9 -3
- hydraflow-0.2.10/src/hydraflow/progress.py +191 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/src/hydraflow/run_collection.py +6 -21
- hydraflow-0.2.10/tests/scripts/progress.py +65 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/test_app.py +0 -9
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/test_progress.py +1 -1
- hydraflow-0.2.8/src/hydraflow/progress.py +0 -131
- {hydraflow-0.2.8 → hydraflow-0.2.10}/.devcontainer/devcontainer.json +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/.devcontainer/postCreate.sh +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/.devcontainer/starship.toml +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/.gitattributes +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/.gitignore +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/LICENSE +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/README.md +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/src/hydraflow/config.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/src/hydraflow/context.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/src/hydraflow/info.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/src/hydraflow/mlflow.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/scripts/__init__.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/scripts/app.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/scripts/watch.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/test_asyncio.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/test_config.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/test_context.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/test_info.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/test_log_run.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/test_mlflow.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/test_run_collection.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/test_version.py +0 -0
- {hydraflow-0.2.8 → hydraflow-0.2.10}/tests/test_watch.py +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: hydraflow
|
3
|
-
Version: 0.2.8
|
3
|
+
Version: 0.2.10
|
4
4
|
Summary: Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments.
|
5
5
|
Project-URL: Documentation, https://github.com/daizutabi/hydraflow
|
6
6
|
Project-URL: Source, https://github.com/daizutabi/hydraflow
|
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "hydraflow"
|
7
|
-
version = "0.2.8"
|
7
|
+
version = "0.2.10"
|
8
8
|
description = "Hydraflow integrates Hydra and MLflow to manage and track machine learning experiments."
|
9
9
|
readme = "README.md"
|
10
10
|
license = "MIT"
|
@@ -63,7 +63,7 @@ asyncio_default_fixture_loop_scope = "function"
|
|
63
63
|
exclude_lines = ["no cov", "raise NotImplementedError", "if TYPE_CHECKING:"]
|
64
64
|
|
65
65
|
[tool.ruff]
|
66
|
-
line-length =
|
66
|
+
line-length = 88
|
67
67
|
target-version = "py312"
|
68
68
|
|
69
69
|
[tool.ruff.lint]
|
@@ -5,6 +5,7 @@ from .mlflow import (
|
|
5
5
|
search_runs,
|
6
6
|
set_experiment,
|
7
7
|
)
|
8
|
+
from .progress import multi_tasks_progress, parallel_progress
|
8
9
|
from .run_collection import RunCollection
|
9
10
|
|
10
11
|
__all__ = [
|
@@ -15,6 +16,8 @@ __all__ = [
|
|
15
16
|
"list_runs",
|
16
17
|
"load_config",
|
17
18
|
"log_run",
|
19
|
+
"multi_tasks_progress",
|
20
|
+
"parallel_progress",
|
18
21
|
"search_runs",
|
19
22
|
"set_experiment",
|
20
23
|
"start_run",
|
@@ -41,7 +41,9 @@ async def execute_command(
|
|
41
41
|
int: The return code of the process.
|
42
42
|
"""
|
43
43
|
try:
|
44
|
-
process = await asyncio.create_subprocess_exec(program, *args, stdout=PIPE, stderr=PIPE)
|
44
|
+
process = await asyncio.create_subprocess_exec(
|
45
|
+
program, *args, stdout=PIPE, stderr=PIPE
|
46
|
+
)
|
45
47
|
await asyncio.gather(
|
46
48
|
process_stream(process.stdout, stdout),
|
47
49
|
process_stream(process.stderr, stderr),
|
@@ -100,7 +102,9 @@ async def monitor_file_changes(
|
|
100
102
|
"""
|
101
103
|
str_paths = [str(path) for path in paths]
|
102
104
|
try:
|
103
|
-
async for changes in watchfiles.awatch(*str_paths, stop_event=stop_event, **awatch_kwargs):
|
105
|
+
async for changes in watchfiles.awatch(
|
106
|
+
*str_paths, stop_event=stop_event, **awatch_kwargs
|
107
|
+
):
|
104
108
|
callback(changes)
|
105
109
|
except Exception as e:
|
106
110
|
logger.error(f"Error watching files: {e}")
|
@@ -129,7 +133,9 @@ async def run_and_monitor(
|
|
129
133
|
"""
|
130
134
|
stop_event = asyncio.Event()
|
131
135
|
run_task = asyncio.create_task(
|
132
|
-
execute_command(program, *args, stop_event=stop_event, stdout=stdout, stderr=stderr)
|
136
|
+
execute_command(
|
137
|
+
program, *args, stop_event=stop_event, stdout=stdout, stderr=stderr
|
138
|
+
)
|
133
139
|
)
|
134
140
|
if watch and paths:
|
135
141
|
monitor_task = asyncio.create_task(
|
@@ -0,0 +1,191 @@
|
|
1
|
+
"""
|
2
|
+
Module for managing progress tracking in parallel processing using Joblib
|
3
|
+
and Rich's Progress bar.
|
4
|
+
|
5
|
+
Provide context managers and functions to facilitate the execution
|
6
|
+
of tasks in parallel while displaying progress updates.
|
7
|
+
|
8
|
+
The following key components are provided:
|
9
|
+
|
10
|
+
- JoblibProgress: A context manager for tracking progress with Rich's Progress
|
11
|
+
bar.
|
12
|
+
- parallel_progress: A function to execute a given function in parallel over
|
13
|
+
an iterable with progress tracking.
|
14
|
+
- multi_tasks_progress: A function to render auto-updating progress bars for
|
15
|
+
multiple tasks concurrently.
|
16
|
+
|
17
|
+
Usage:
|
18
|
+
Import the necessary functions and use them to manage progress in your
|
19
|
+
parallel processing tasks.
|
20
|
+
"""
|
21
|
+
|
22
|
+
from __future__ import annotations
|
23
|
+
|
24
|
+
from contextlib import contextmanager
|
25
|
+
from typing import TYPE_CHECKING, TypeVar
|
26
|
+
|
27
|
+
import joblib
|
28
|
+
from rich.progress import Progress
|
29
|
+
|
30
|
+
if TYPE_CHECKING:
|
31
|
+
from collections.abc import Callable, Iterable, Iterator
|
32
|
+
|
33
|
+
from rich.progress import ProgressColumn
|
34
|
+
|
35
|
+
|
36
|
+
# https://github.com/jonghwanhyeon/joblib-progress/blob/main/joblib_progress/__init__.py
|
37
|
+
@contextmanager
|
38
|
+
def JoblibProgress(
|
39
|
+
*columns: ProgressColumn | str,
|
40
|
+
description: str | None = None,
|
41
|
+
total: int | None = None,
|
42
|
+
**kwargs,
|
43
|
+
) -> Iterator[Progress]:
|
44
|
+
"""
|
45
|
+
Context manager for tracking progress using Joblib with Rich's Progress bar.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
*columns (ProgressColumn | str): Columns to display in the progress bar.
|
49
|
+
description (str | None, optional): A description for the progress task.
|
50
|
+
Defaults to None.
|
51
|
+
total (int | None, optional): The total number of tasks. If None, it will
|
52
|
+
be determined automatically.
|
53
|
+
**kwargs: Additional keyword arguments passed to the Progress instance.
|
54
|
+
|
55
|
+
Yields:
|
56
|
+
Progress: A Progress instance for managing the progress bar.
|
57
|
+
|
58
|
+
Example:
|
59
|
+
with JoblibProgress("task", total=100) as progress:
|
60
|
+
# Your parallel processing code here
|
61
|
+
"""
|
62
|
+
if not columns:
|
63
|
+
columns = Progress.get_default_columns()
|
64
|
+
|
65
|
+
progress = Progress(*columns, **kwargs)
|
66
|
+
|
67
|
+
if description is None:
|
68
|
+
description = "Processing..."
|
69
|
+
|
70
|
+
task_id = progress.add_task(description, total=total)
|
71
|
+
print_progress = joblib.parallel.Parallel.print_progress
|
72
|
+
|
73
|
+
def update_progress(self: joblib.parallel.Parallel):
|
74
|
+
progress.update(task_id, completed=self.n_completed_tasks, refresh=True)
|
75
|
+
return print_progress(self)
|
76
|
+
|
77
|
+
try:
|
78
|
+
joblib.parallel.Parallel.print_progress = update_progress
|
79
|
+
progress.start()
|
80
|
+
yield progress
|
81
|
+
|
82
|
+
finally:
|
83
|
+
progress.stop()
|
84
|
+
joblib.parallel.Parallel.print_progress = print_progress
|
85
|
+
|
86
|
+
|
87
|
+
T = TypeVar("T")
|
88
|
+
U = TypeVar("U")
|
89
|
+
|
90
|
+
|
91
|
+
def parallel_progress(
|
92
|
+
func: Callable[[T], U],
|
93
|
+
iterable: Iterable[T],
|
94
|
+
*columns: ProgressColumn | str,
|
95
|
+
n_jobs: int = -1,
|
96
|
+
description: str | None = None,
|
97
|
+
**kwargs,
|
98
|
+
) -> list[U]:
|
99
|
+
"""
|
100
|
+
Execute a function in parallel over an iterable with progress tracking.
|
101
|
+
|
102
|
+
Args:
|
103
|
+
func (Callable[[T], U]): The function to execute on each item in the
|
104
|
+
iterable.
|
105
|
+
iterable (Iterable[T]): An iterable of items to process.
|
106
|
+
*columns (ProgressColumn | str): Additional columns to display in the
|
107
|
+
progress bar.
|
108
|
+
n_jobs (int, optional): The number of jobs to run in parallel.
|
109
|
+
Defaults to -1 (all processors).
|
110
|
+
description (str | None, optional): A description for the progress bar.
|
111
|
+
Defaults to None.
|
112
|
+
**kwargs: Additional keyword arguments passed to the Progress instance.
|
113
|
+
|
114
|
+
Returns:
|
115
|
+
list[U]: A list of results from applying the function to each item in
|
116
|
+
the iterable.
|
117
|
+
"""
|
118
|
+
iterable = list(iterable)
|
119
|
+
total = len(iterable)
|
120
|
+
|
121
|
+
with JoblibProgress(*columns, description=description, total=total, **kwargs):
|
122
|
+
it = (joblib.delayed(func)(x) for x in iterable)
|
123
|
+
return joblib.Parallel(n_jobs=n_jobs)(it) # type: ignore
|
124
|
+
|
125
|
+
|
126
|
+
def multi_tasks_progress(
|
127
|
+
iterables: Iterable[Iterable[int | tuple[int, int]]],
|
128
|
+
*columns: ProgressColumn | str,
|
129
|
+
n_jobs: int = -1,
|
130
|
+
description: str = "#{:0>3}",
|
131
|
+
main_description: str = "main",
|
132
|
+
transient: bool | None = None,
|
133
|
+
**kwargs,
|
134
|
+
) -> None:
|
135
|
+
"""
|
136
|
+
Render auto-updating progress bars for multiple tasks concurrently.
|
137
|
+
|
138
|
+
Args:
|
139
|
+
iterables (Iterable[Iterable[int | tuple[int, int]]]): A collection of
|
140
|
+
iterables, each representing a task. Each iterable can yield
|
141
|
+
integers (completed) or tuples of integers (completed, total).
|
142
|
+
*columns (ProgressColumn | str): Additional columns to display in the
|
143
|
+
progress bars.
|
144
|
+
n_jobs (int, optional): Number of jobs to run in parallel. Defaults to
|
145
|
+
-1, which means using all processors.
|
146
|
+
description (str, optional): Format string for describing tasks. Defaults to
|
147
|
+
"#{:0>3}".
|
148
|
+
main_description (str, optional): Description for the main task.
|
149
|
+
Defaults to "main".
|
150
|
+
transient (bool | None, optional): Whether to remove the progress bar
|
151
|
+
after completion. Defaults to None.
|
152
|
+
**kwargs: Additional keyword arguments passed to the Progress instance.
|
153
|
+
|
154
|
+
Returns:
|
155
|
+
None
|
156
|
+
"""
|
157
|
+
if not columns:
|
158
|
+
columns = Progress.get_default_columns()
|
159
|
+
|
160
|
+
iterables = list(iterables)
|
161
|
+
|
162
|
+
with Progress(*columns, transient=transient or False, **kwargs) as progress:
|
163
|
+
task_main = progress.add_task(main_description, total=None)
|
164
|
+
total = {}
|
165
|
+
completed = {}
|
166
|
+
|
167
|
+
def func(i: int, iterable: Iterable[int | tuple[int, int]]) -> None:
|
168
|
+
task_id = progress.add_task(description.format(i), total=None)
|
169
|
+
completed[i] = 0
|
170
|
+
total[i] = None
|
171
|
+
|
172
|
+
for index in iterable:
|
173
|
+
if isinstance(index, tuple):
|
174
|
+
completed[i], total[i] = index[0] + 1, index[1]
|
175
|
+
else:
|
176
|
+
completed[i] = index + 1
|
177
|
+
|
178
|
+
progress.update(task_id, total=total[i], completed=completed[i])
|
179
|
+
|
180
|
+
if all(t is not None for t in total.values()):
|
181
|
+
t = sum(total.values())
|
182
|
+
else:
|
183
|
+
t = None
|
184
|
+
c = sum(completed.values())
|
185
|
+
progress.update(task_main, total=t, completed=c)
|
186
|
+
|
187
|
+
if transient is not False:
|
188
|
+
progress.remove_task(task_id)
|
189
|
+
|
190
|
+
it = (joblib.delayed(func)(i, it) for i, it in enumerate(iterables))
|
191
|
+
joblib.Parallel(n_jobs, prefer="threads")(it)
|
@@ -468,7 +468,9 @@ class RunCollection:
|
|
468
468
|
"""
|
469
469
|
return (func(dir, *args, **kwargs) for dir in self.info.artifact_dir)
|
470
470
|
|
471
|
-
def group_by(self, *names: str | list[str]) -> dict[tuple[str | None, ...], RunCollection]:
|
471
|
+
def group_by(
|
472
|
+
self, *names: str | list[str]
|
473
|
+
) -> dict[tuple[str | None, ...], RunCollection]:
|
472
474
|
"""
|
473
475
|
Group runs by specified parameter names.
|
474
476
|
|
@@ -493,25 +495,6 @@ class RunCollection:
|
|
493
495
|
|
494
496
|
return {key: RunCollection(runs) for key, runs in grouped_runs.items()}
|
495
497
|
|
496
|
-
def group_by_values(self, *names: str | list[str]) -> list[RunCollection]:
|
497
|
-
"""
|
498
|
-
Group runs by specified parameter names.
|
499
|
-
|
500
|
-
This method groups the runs in the collection based on the values of the
|
501
|
-
specified parameters. Each unique combination of parameter values will
|
502
|
-
form a separate RunCollection in the returned list.
|
503
|
-
|
504
|
-
Args:
|
505
|
-
*names (str | list[str]): The names of the parameters to group by.
|
506
|
-
This can be a single parameter name or multiple names provided
|
507
|
-
as separate arguments or as a list.
|
508
|
-
|
509
|
-
Returns:
|
510
|
-
list[RunCollection]: A list of RunCollection objects, where each
|
511
|
-
object contains runs that match the specified parameter values.
|
512
|
-
"""
|
513
|
-
return list(self.group_by(*names).values())
|
514
|
-
|
515
498
|
|
516
499
|
def _param_matches(run: Run, key: str, value: Any) -> bool:
|
517
500
|
"""
|
@@ -671,7 +654,9 @@ def find_last_run(runs: list[Run], config: object | None = None, **kwargs) -> Ru
|
|
671
654
|
return filtered_runs[-1]
|
672
655
|
|
673
656
|
|
674
|
-
def try_find_last_run(runs: list[Run], config: object | None = None, **kwargs) -> Run | None:
|
657
|
+
def try_find_last_run(
|
658
|
+
runs: list[Run], config: object | None = None, **kwargs
|
659
|
+
) -> Run | None:
|
675
660
|
"""
|
676
661
|
Find the last run based on the provided configuration.
|
677
662
|
|
@@ -0,0 +1,65 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import random
|
4
|
+
import time
|
5
|
+
|
6
|
+
from rich.progress import (
|
7
|
+
MofNCompleteColumn,
|
8
|
+
Progress,
|
9
|
+
SpinnerColumn,
|
10
|
+
TimeElapsedColumn,
|
11
|
+
)
|
12
|
+
|
13
|
+
from hydraflow import multi_tasks_progress, parallel_progress
|
14
|
+
|
15
|
+
|
16
|
+
def test_parallel_progress(**kwargs):
|
17
|
+
def func(x: int) -> str:
|
18
|
+
time.sleep(1)
|
19
|
+
return f"result: {x}"
|
20
|
+
|
21
|
+
it = range(12)
|
22
|
+
|
23
|
+
columns = [
|
24
|
+
SpinnerColumn(),
|
25
|
+
*Progress.get_default_columns(),
|
26
|
+
MofNCompleteColumn(),
|
27
|
+
TimeElapsedColumn(),
|
28
|
+
]
|
29
|
+
|
30
|
+
parallel_progress(func, it, *columns, n_jobs=-1, **kwargs)
|
31
|
+
|
32
|
+
|
33
|
+
def task(total):
|
34
|
+
for i in range(total or 90):
|
35
|
+
if total is None:
|
36
|
+
yield i
|
37
|
+
else:
|
38
|
+
yield i, total
|
39
|
+
time.sleep(random.random() / 30)
|
40
|
+
|
41
|
+
|
42
|
+
def test_multi_tasks_progress(total: bool, **kwargs):
|
43
|
+
tasks = (task(random.randint(80, 100)) for _ in range(4))
|
44
|
+
if total:
|
45
|
+
tasks = (task(None), *list(tasks)[:2], task(None))
|
46
|
+
|
47
|
+
columns = [
|
48
|
+
SpinnerColumn(),
|
49
|
+
*Progress.get_default_columns(),
|
50
|
+
MofNCompleteColumn(),
|
51
|
+
TimeElapsedColumn(),
|
52
|
+
]
|
53
|
+
|
54
|
+
if total:
|
55
|
+
kwargs["main_description"] = "unknown"
|
56
|
+
|
57
|
+
multi_tasks_progress(tasks, *columns, n_jobs=4, **kwargs)
|
58
|
+
|
59
|
+
|
60
|
+
if __name__ == "__main__":
|
61
|
+
test_parallel_progress(description="parallel")
|
62
|
+
test_parallel_progress(transient=True)
|
63
|
+
test_multi_tasks_progress(False)
|
64
|
+
test_multi_tasks_progress(True, transient=False)
|
65
|
+
test_multi_tasks_progress(False, transient=True)
|
@@ -98,12 +98,3 @@ def test_app_group_by(rc: RunCollection):
|
|
98
98
|
assert grouped[("x",)].info.params[1] == {"port": "2", "host": "x"}
|
99
99
|
assert grouped[("y",)].info.params[0] == {"port": "1", "host": "y"}
|
100
100
|
assert grouped[("y",)].info.params[1] == {"port": "2", "host": "y"}
|
101
|
-
|
102
|
-
|
103
|
-
def test_app_group_by_values(rc: RunCollection):
|
104
|
-
grouped = rc.group_by_values("port")
|
105
|
-
assert len(grouped) == 2
|
106
|
-
assert grouped[0].info.params[0] == {"port": "1", "host": "x"}
|
107
|
-
assert grouped[0].info.params[1] == {"port": "1", "host": "y"}
|
108
|
-
assert grouped[1].info.params[0] == {"port": "2", "host": "x"}
|
109
|
-
assert grouped[1].info.params[1] == {"port": "2", "host": "y"}
|
@@ -8,5 +8,5 @@ import pytest
|
|
8
8
|
sys.platform == "win32", reason="'cp932' codec can't encode character '\\u2807'"
|
9
9
|
)
|
10
10
|
def test_progress_bar():
|
11
|
-
cp = run([sys.executable, "src/hydraflow/progress.py"])
|
11
|
+
cp = run([sys.executable, "tests/scripts/progress.py"])
|
12
12
|
assert cp.returncode == 0
|
@@ -1,131 +0,0 @@
|
|
1
|
-
from __future__ import annotations
|
2
|
-
|
3
|
-
from typing import TYPE_CHECKING
|
4
|
-
|
5
|
-
import joblib
|
6
|
-
from rich.progress import Progress
|
7
|
-
|
8
|
-
if TYPE_CHECKING:
|
9
|
-
from collections.abc import Iterable
|
10
|
-
|
11
|
-
from rich.progress import ProgressColumn
|
12
|
-
|
13
|
-
|
14
|
-
def multi_task_progress(
|
15
|
-
iterables: Iterable[Iterable[int | tuple[int, int]]],
|
16
|
-
*columns: ProgressColumn | str,
|
17
|
-
n_jobs: int = -1,
|
18
|
-
description: str = "#{:0>3}",
|
19
|
-
main_description: str = "main",
|
20
|
-
transient: bool | None = None,
|
21
|
-
**kwargs,
|
22
|
-
) -> None:
|
23
|
-
"""
|
24
|
-
Render auto-updating progress bars for multiple tasks concurrently.
|
25
|
-
|
26
|
-
Args:
|
27
|
-
iterables (Iterable[Iterable[int | tuple[int, int]]]): A collection of
|
28
|
-
iterables, each representing a task. Each iterable can yield
|
29
|
-
integers (completed) or tuples of integers (completed, total).
|
30
|
-
*columns (ProgressColumn | str): Additional columns to display in the
|
31
|
-
progress bars.
|
32
|
-
n_jobs (int, optional): Number of jobs to run in parallel. Defaults to
|
33
|
-
-1, which means using all processors.
|
34
|
-
description (str, optional): Format string for describing tasks. Defaults to
|
35
|
-
"#{:0>3}".
|
36
|
-
main_description (str, optional): Description for the main task.
|
37
|
-
Defaults to "main".
|
38
|
-
transient (bool | None, optional): Whether to remove the progress bar
|
39
|
-
after completion. Defaults to None.
|
40
|
-
**kwargs: Additional keyword arguments passed to the Progress instance.
|
41
|
-
|
42
|
-
Returns:
|
43
|
-
None
|
44
|
-
"""
|
45
|
-
if not columns:
|
46
|
-
columns = Progress.get_default_columns()
|
47
|
-
|
48
|
-
iterables = list(iterables)
|
49
|
-
|
50
|
-
with Progress(*columns, transient=transient or False, **kwargs) as progress:
|
51
|
-
n = len(iterables)
|
52
|
-
|
53
|
-
task_main = progress.add_task(main_description, total=None) if n > 1 else None
|
54
|
-
tasks = [
|
55
|
-
progress.add_task(description.format(i), start=False, total=None) for i in range(n)
|
56
|
-
]
|
57
|
-
|
58
|
-
total = {}
|
59
|
-
completed = {}
|
60
|
-
|
61
|
-
def func(i: int) -> None:
|
62
|
-
completed[i] = 0
|
63
|
-
total[i] = None
|
64
|
-
progress.start_task(tasks[i])
|
65
|
-
|
66
|
-
for index in iterables[i]:
|
67
|
-
if isinstance(index, tuple):
|
68
|
-
completed[i], total[i] = index[0] + 1, index[1]
|
69
|
-
else:
|
70
|
-
completed[i] = index + 1
|
71
|
-
|
72
|
-
progress.update(tasks[i], total=total[i], completed=completed[i])
|
73
|
-
if task_main is not None:
|
74
|
-
if all(t is not None for t in total.values()):
|
75
|
-
t = sum(total.values())
|
76
|
-
else:
|
77
|
-
t = None
|
78
|
-
c = sum(completed.values())
|
79
|
-
progress.update(task_main, total=t, completed=c)
|
80
|
-
|
81
|
-
if transient or n > 1:
|
82
|
-
progress.remove_task(tasks[i])
|
83
|
-
|
84
|
-
if n > 1:
|
85
|
-
it = (joblib.delayed(func)(i) for i in range(n))
|
86
|
-
joblib.Parallel(n_jobs, prefer="threads")(it)
|
87
|
-
|
88
|
-
else:
|
89
|
-
func(0)
|
90
|
-
|
91
|
-
|
92
|
-
if __name__ == "__main__":
|
93
|
-
import random
|
94
|
-
import time
|
95
|
-
|
96
|
-
from rich.progress import MofNCompleteColumn, Progress, SpinnerColumn, TimeElapsedColumn
|
97
|
-
|
98
|
-
from hydraflow.progress import multi_task_progress
|
99
|
-
|
100
|
-
def task(total):
|
101
|
-
for i in range(total or 90):
|
102
|
-
if total is None:
|
103
|
-
yield i
|
104
|
-
else:
|
105
|
-
yield i, total
|
106
|
-
time.sleep(random.random() / 30)
|
107
|
-
|
108
|
-
def multi_task_progress_test(unknown_total: bool):
|
109
|
-
tasks = [task(random.randint(80, 100)) for _ in range(4)]
|
110
|
-
if unknown_total:
|
111
|
-
tasks = [task(None), *tasks, task(None)]
|
112
|
-
|
113
|
-
columns = [
|
114
|
-
SpinnerColumn(),
|
115
|
-
*Progress.get_default_columns(),
|
116
|
-
MofNCompleteColumn(),
|
117
|
-
TimeElapsedColumn(),
|
118
|
-
]
|
119
|
-
|
120
|
-
kwargs = {}
|
121
|
-
if unknown_total:
|
122
|
-
kwargs["main_description"] = "unknown"
|
123
|
-
|
124
|
-
multi_task_progress(tasks, *columns, n_jobs=4, **kwargs)
|
125
|
-
|
126
|
-
multi_task_progress_test(False)
|
127
|
-
multi_task_progress_test(True)
|
128
|
-
multi_task_progress([task(100)])
|
129
|
-
multi_task_progress([task(None)], description="unknown")
|
130
|
-
multi_task_progress([task(100), task(None)], main_description="transient", transient=True)
|
131
|
-
multi_task_progress([task(100)], description="transient", transient=True)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|