glow 0.15.6__tar.gz → 0.15.8__tar.gz
This diff compares two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- {glow-0.15.6 → glow-0.15.8}/PKG-INFO +16 -14
- {glow-0.15.6 → glow-0.15.8}/pyproject.toml +8 -6
- {glow-0.15.6 → glow-0.15.8}/src/glow/__init__.py +4 -1
- {glow-0.15.6 → glow-0.15.8}/src/glow/_async.py +50 -8
- {glow-0.15.6 → glow-0.15.8}/src/glow/_async.pyi +6 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_parallel.py +109 -82
- {glow-0.15.6 → glow-0.15.8}/src/glow/_parallel.pyi +1 -1
- {glow-0.15.6 → glow-0.15.8}/src/glow/_reduction.py +0 -5
- {glow-0.15.6 → glow-0.15.8}/src/glow/_thread_quota.py +16 -2
- {glow-0.15.6 → glow-0.15.8}/src/glow/cli.py +46 -14
- {glow-0.15.6 → glow-0.15.8}/src/glow/cli.pyi +6 -0
- {glow-0.15.6 → glow-0.15.8}/test/test_cli.py +21 -2
- {glow-0.15.6 → glow-0.15.8}/.gitignore +0 -0
- {glow-0.15.6 → glow-0.15.8}/LICENSE +0 -0
- {glow-0.15.6 → glow-0.15.8}/README.md +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_array.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_cache.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_cache.pyi +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_concurrency.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_concurrency.pyi +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_coro.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_debug.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_dev.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_futures.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_ic.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_import_hook.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_imutil.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_keys.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_logging.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_more.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_patch_len.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_patch_print.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_patch_scipy.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_profile.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_profile.pyi +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_repr.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_reusable.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_sizeof.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_streams.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_types.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_uuid.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/_wrap.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/api/__init__.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/api/config.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/api/exporting.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/io/__init__.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/io/_sound.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/io/_svg.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/src/glow/py.typed +0 -0
- {glow-0.15.6 → glow-0.15.8}/test/__init__.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/test/test_api.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/test/test_batch.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/test/test_buffered.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/test/test_iter.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/test/test_shm.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/test/test_thread_pool.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/test/test_timed.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/test/test_timer.py +0 -0
- {glow-0.15.6 → glow-0.15.8}/test/test_uuid.py +0 -0
{glow-0.15.6 → glow-0.15.8}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: glow
-Version: 0.15.6
+Version: 0.15.8
 Summary: Functional Python tools
 Project-URL: homepage, https://github.com/arquolo/glow
 Author-email: Paul Maevskikh <arquolo@gmail.com>
@@ -33,13 +33,15 @@ Classifier: Operating System :: OS Independent
 Classifier: Programming Language :: Python :: 3
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
-Requires-Python: >=3.12
+Classifier: Programming Language :: Python :: 3.14
+Requires-Python: <3.15,>=3.12
 Requires-Dist: loguru
 Requires-Dist: loky~=3.1
 Requires-Dist: lxml
 Requires-Dist: numpy<3,>=1.21
 Requires-Dist: tqdm
 Requires-Dist: typing-extensions~=4.1; python_version < '3.11'
+Requires-Dist: typing-inspection~=0.4.1
 Requires-Dist: wrapt~=1.15
 Provides-Extra: all
 Requires-Dist: asttokens; extra == 'all'
@@ -51,7 +53,7 @@ Requires-Dist: pygments; extra == 'all'
 Requires-Dist: sounddevice; extra == 'all'
 Requires-Dist: soundfile; extra == 'all'
 Provides-Extra: dev
-Requires-Dist: black~=
+Requires-Dist: black~=26.1; extra == 'dev'
 Requires-Dist: flake8-alphabetize; extra == 'dev'
 Requires-Dist: flake8-pie; extra == 'dev'
 Requires-Dist: flake8-pyi; extra == 'dev'
@@ -59,34 +61,34 @@ Requires-Dist: flake8-pyproject; extra == 'dev'
 Requires-Dist: flake8-simplify; extra == 'dev'
 Requires-Dist: flake8~=7.0; extra == 'dev'
 Requires-Dist: isort; extra == 'dev'
-Requires-Dist: mypy~=1.
+Requires-Dist: mypy~=1.19; extra == 'dev'
 Requires-Dist: pytest-asyncio; extra == 'dev'
-Requires-Dist: pytest~=
-Requires-Dist: ruff~=0.
+Requires-Dist: pytest~=9.0; extra == 'dev'
+Requires-Dist: ruff~=0.15.0; extra == 'dev'
 Provides-Extra: dev-core
-Requires-Dist: black~=
+Requires-Dist: black~=26.1; extra == 'dev-core'
 Requires-Dist: flake8-pie; extra == 'dev-core'
 Requires-Dist: flake8-pyi; extra == 'dev-core'
 Requires-Dist: flake8-pyproject; extra == 'dev-core'
 Requires-Dist: flake8-simplify; extra == 'dev-core'
 Requires-Dist: flake8~=7.0; extra == 'dev-core'
 Requires-Dist: isort; extra == 'dev-core'
-Requires-Dist: mypy~=1.
+Requires-Dist: mypy~=1.19; extra == 'dev-core'
 Requires-Dist: pytest-asyncio; extra == 'dev-core'
-Requires-Dist: pytest~=
-Requires-Dist: ruff~=0.
+Requires-Dist: pytest~=9.0; extra == 'dev-core'
+Requires-Dist: ruff~=0.15.0; extra == 'dev-core'
 Provides-Extra: dev-wemake
-Requires-Dist: black~=
+Requires-Dist: black~=26.1; extra == 'dev-wemake'
 Requires-Dist: flake8-pie; extra == 'dev-wemake'
 Requires-Dist: flake8-pyi; extra == 'dev-wemake'
 Requires-Dist: flake8-pyproject; extra == 'dev-wemake'
 Requires-Dist: flake8-simplify; extra == 'dev-wemake'
 Requires-Dist: flake8~=7.0; extra == 'dev-wemake'
 Requires-Dist: isort; extra == 'dev-wemake'
-Requires-Dist: mypy~=1.
+Requires-Dist: mypy~=1.19; extra == 'dev-wemake'
 Requires-Dist: pytest-asyncio; extra == 'dev-wemake'
-Requires-Dist: pytest~=
-Requires-Dist: ruff~=0.
+Requires-Dist: pytest~=9.0; extra == 'dev-wemake'
+Requires-Dist: ruff~=0.15.0; extra == 'dev-wemake'
 Requires-Dist: wemake-python-styleguide~=1.3.0; extra == 'dev-wemake'
 Provides-Extra: ic
 Requires-Dist: asttokens; extra == 'ic'
{glow-0.15.6 → glow-0.15.8}/pyproject.toml

@@ -7,10 +7,10 @@ only-packages = true
 
 [project]
 name = "glow"
-version = "0.15.6"
+version = "0.15.8"
 description = "Functional Python tools"
 readme = "README.md"
-requires-python = ">=3.12"
+requires-python = ">=3.12, <3.15"
 license = {file = "LICENSE"}
 keywords = []
 authors = [
@@ -26,6 +26,7 @@ classifiers = [
     "Programming Language :: Python :: 3",
     "Programming Language :: Python :: 3.12",
     "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: 3.14",
 ]
 dependencies = [
     "loguru",
@@ -33,6 +34,7 @@ dependencies = [
     "lxml",
     "numpy >=1.21, <3",
     "typing-extensions~=4.1; python_version < '3.11'",
+    "typing-inspection~=0.4.1",
     "tqdm",
     "wrapt~=1.15",
 ]
@@ -61,17 +63,17 @@ all = [
     "matplotlib",
 ]
 dev-core = [
-    "black~=
+    "black~=26.1",
     "flake8~=7.0",
     "flake8-pie",
     "flake8-pyi",
     "flake8-pyproject",
     "flake8-simplify",
     "isort",
-    "mypy~=1.
-    "pytest~=
+    "mypy~=1.19",
+    "pytest~=9.0",
     "pytest-asyncio",
-    "ruff~=0.
+    "ruff~=0.15.0",
 ]
 dev = [
     "glow[dev-core]",
{glow-0.15.6 → glow-0.15.8}/src/glow/__init__.py

@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING
 
 from . import _patch_len, _patch_print, _patch_scipy
 from ._array import aceil, afloor, apack, around, pascal
-from ._async import amap, amap_dict, astarmap, astreaming, azip
+from ._async import RwLock, amap, amap_dict, astarmap, astreaming, azip
 from ._cache import cache_status, memoize
 from ._concurrency import (
     call_once,
@@ -17,6 +17,7 @@ from ._concurrency import (
 )
 from ._coro import as_actor, coroutine, summary
 from ._debug import lock_seed, trace, trace_module, whereami
+from ._dev import hide_frame
 from ._import_hook import register_post_import_hook, when_imported
 from ._logging import init_loguru
 from ._more import (
@@ -69,6 +70,7 @@ else:
 
 __all__ = [
     'Reusable',
+    'RwLock',
     'Uid',
     'aceil',
     'afloor',
@@ -91,6 +93,7 @@ __all__ = [
     'eat',
     'get_executor',
     'groupby',
+    'hide_frame',
     'ic',
     'ic_repr',
     'ichunked',
{glow-0.15.6 → glow-0.15.8}/src/glow/_async.py

@@ -1,7 +1,7 @@
-__all__ = ['amap', 'amap_dict', 'astarmap', 'azip']
+__all__ = ['RwLock', 'amap', 'amap_dict', 'astarmap', 'azip']
 
 import asyncio
-from asyncio import Queue, Task
+from asyncio import Event, Future, Lock, Queue, Task, TaskGroup
 from collections import deque
 from collections.abc import (
     AsyncIterator,
@@ -12,7 +12,7 @@ from collections.abc import (
     Mapping,
     Sequence,
 )
-from contextlib import suppress
+from contextlib import asynccontextmanager, suppress
 from functools import partial
 from typing import TypeGuard, cast, overload
 
@@ -84,7 +84,7 @@ async def astarmap[*Ts, R](
             yield await func(*args)
         return
 
-    async with
+    async with TaskGroup() as tg:
         ts = (
             (tg.create_task(func(*args)) for args in iterable)
             if isinstance(iterable, Iterable)
@@ -256,8 +256,8 @@ def astreaming[T, R](
 
     buf: list[Job[T, R]] = []
     deadline = float('-inf')
-    not_last =
-    lock =
+    not_last = Event()
+    lock = Lock()
    ncalls = 0
 
     async def wrapper(items: Sequence[T]) -> list[R]:
@@ -270,10 +270,10 @@ def astreaming[T, R](
             not_last.set()
 
         ncalls += 1
-        fs: list[
+        fs: list[Future[R]] = []
         try:
             for x in items:
-                f =
+                f = Future[R]()
                 fs.append(f)
                 buf.append((x, f))
 
@@ -305,3 +305,45 @@ def astreaming[T, R](
         return await asyncio.gather(*fs)
 
     return wrapper
+
+
+# ----------------------------- read/write guard -----------------------------
+
+
+class RwLock:
+    """Guard code from concurrent writes.
+
+    Reads are not limited.
+    When write is issued, new reads are delayed until write is finished.
+    """
+
+    def __init__(self) -> None:
+        self._num_reads = 0
+        self._readable = Event()
+        self._readable.set()
+        self._writable = Event()
+        self._writable.set()
+
+    @asynccontextmanager
+    async def read(self) -> AsyncIterator[None]:
+        await self._readable.wait()
+        self._writable.clear()
+        try:
+            yield
+        finally:
+            self._num_reads -= 1
+            if self._num_reads == 0:
+                self._writable.set()
+
+    @asynccontextmanager
+    async def write(self) -> AsyncIterator[None]:
+        self._readable.clear()  # Stop new READs
+        try:
+            await self._writable.wait()  # Wait for all READs or single WRITE
+            self._writable.clear()  # Only single WRITE is allowed
+            try:
+                yield
+            finally:
+                self._writable.set()
+        finally:
+            self._readable.set()
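The new `RwLock` above is an asyncio many-readers/one-writer guard, re-exported as `glow.RwLock`. A minimal usage sketch, assuming only the `read()`/`write()` async context managers shown in the diff; the cache dict and coroutine names are illustrative:

import asyncio

from glow import RwLock  # new in 0.15.8

lock = RwLock()
cache: dict[str, int] = {}

async def writer(key: str, value: int) -> None:
    async with lock.write():  # exclusive: waits out reads, blocks new ones
        cache[key] = value

async def reader(key: str) -> int | None:
    async with lock.read():  # reads may run concurrently with each other
        return cache.get(key)

async def main() -> None:
    await writer('a', 1)
    print(await asyncio.gather(reader('a'), reader('b')))  # [1, None]

asyncio.run(main())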
{glow-0.15.6 → glow-0.15.8}/src/glow/_async.pyi

@@ -1,4 +1,5 @@
 from collections.abc import AsyncIterator, Callable, Mapping
+from contextlib import AbstractAsyncContextManager
 from typing import Any, Required, TypedDict, Unpack, overload
 
 from ._futures import ABatchDecorator, ABatchFn
@@ -107,3 +108,8 @@ def astreaming[T, R](
     batch_size: int | None = ...,
     timeout: float = ...,
 ) -> ABatchFn[T, R]: ...
+
+class RwLock:
+    def __init__(self) -> None: ...
+    def read(self) -> AbstractAsyncContextManager: ...
+    def write(self) -> AbstractAsyncContextManager: ...
{glow-0.15.6 → glow-0.15.8}/src/glow/_parallel.py

@@ -9,6 +9,7 @@ __all__ = [
 
 import atexit
 import enum
+import logging
 import os
 import signal
 import sys
@@ -20,7 +21,6 @@ from contextlib import ExitStack, contextmanager
 from cProfile import Profile
 from functools import partial
 from itertools import chain, islice, starmap
-from logging import getLogger
 from multiprocessing.managers import BaseManager
 from operator import methodcaller
 from pstats import Stats
@@ -42,6 +42,8 @@ from ._reduction import move_to_shmem, reducers
 from ._thread_quota import ThreadQuota
 from ._types import Get, Some
 
+_LOGGER = logging.getLogger(__name__)
+
 _TOTAL_CPUS = (
     os.process_cpu_count() if sys.version_info >= (3, 13) else os.cpu_count()
 )
@@ -50,9 +52,11 @@ _NUM_CPUS = _TOTAL_CPUS or 0
 if (_env_cpus := os.getenv('GLOW_CPUS')) is not None:
     _NUM_CPUS = min(_NUM_CPUS, int(_env_cpus))
 _NUM_CPUS = max(_NUM_CPUS, 0)
+
 _IDLE_WORKER_TIMEOUT = 10
-
-
+# TODO: investigate whether this improves load
+_FAST_GROW = False
+_GRANULAR_SCHEDULING = False
 
 
 class _Empty(enum.Enum):
@@ -79,19 +83,16 @@ class _Manager(Protocol):
     def Queue(self, /, maxsize: int) -> _Queue: ...  # noqa: N802
 
 
-def _get_cpu_count_limits(
-    upper_bound: int = sys.maxsize, *, mp: bool = False
-) -> Iterator[int]:
-    yield from (upper_bound, _TOTAL_CPUS or 1)
-
+def _torch_limit() -> int | None:
     # Windows platform lacks memory overcommit, so it's sensitive to VMS growth
-    if
-        return
+    if sys.platform != 'win32':
+        return None
 
-
+    torch = sys.modules.get('torch')
+    if torch is None or (torch.version.cuda or '') >= '11.7.0':
         # It's expected that torch will fix .nv_fatb readonly flag in its DLLs
         # See https://stackoverflow.com/a/69489193/9868257
-        return
+        return None
 
     if psutil is None:
         warnings.warn(
@@ -100,16 +101,24 @@ def _get_cpu_count_limits(
            'Install psutil to avoid that',
            stacklevel=3,
         )
-    return
+        return None
 
-    #
+    # Windows has no overcommit, checking how much processes fit into VMS
     vms: int = psutil.Process().memory_info().vms
     free_vms: int = psutil.virtual_memory().free + psutil.swap_memory().free
-
+    return free_vms // vms
+
 
+def max_cpu_count(limit: int | None = None, *, mp: bool = False) -> int:
+    limits = [_TOTAL_CPUS or 1]
 
-
-
+    if limit is not None:
+        limits.append(limit)
+
+    if mp and (torch_limit := _torch_limit()) is not None:
+        limits.append(torch_limit)
+
+    return min(limits)
 
 
 _PATIENCE = 0.01
@@ -119,7 +128,7 @@ class _TimeoutCallable[T](Protocol):
     def __call__(self, *, timeout: float) -> T: ...
 
 
-def _retry_call[T](fn: _TimeoutCallable[T], *exc: type[BaseException]) -> T:
+def _retrying[T](f: _TimeoutCallable[T], *exc: type[BaseException]) -> T:
     # See issues
     # https://github.com/dask/dask/pull/2144#issuecomment-290556996
     # https://github.com/dask/dask/pull/2144/files
@@ -127,7 +136,7 @@ def _retry_call[T](fn: _TimeoutCallable[T], *exc: type[BaseException]) -> T:
     # FIXED in py3.15+
     while True:
         try:
-            return fn(timeout=_PATIENCE)
+            return f(timeout=_PATIENCE)
         except exc:
             sleep(0)  # Force switch to another thread to proceed
 
@@ -135,7 +144,7 @@ def _retry_call[T](fn: _TimeoutCallable[T], *exc: type[BaseException]) -> T:
 if sys.platform == 'win32':
 
     def _exception[T](f: Future[T], /) -> BaseException | None:
-        return _retry_call(f.exception, TimeoutError)
+        return _retrying(f.exception, TimeoutError)
 
 else:
     _exception = Future.exception
@@ -153,7 +162,7 @@ def _result[T](f: Future[T], cancel: bool = True) -> Some[T] | BaseException:
 def _q_get_fn[T](q: _Queue[T]) -> Get[T]:
     if sys.platform != 'win32':
         return q.get
-    return partial(_retry_call, q.get, Empty)
+    return partial(_retrying, q.get, Empty)
 
 
 # ---------------------------- pool initialization ----------------------------
@@ -325,8 +334,8 @@ class _AutoSize:
            return  # Or duration is less then `monotonic()` precision
 
        if self.duration < self.MIN_DURATION:  # Too high IPC overhead
-           size = self.size * 2
-           _LOGGER.debug('
+           size = self._new_scale() if _FAST_GROW else self.size * 2
+           _LOGGER.debug('Increasing batch size to %d', size)
 
        elif (
            self.duration <= self.MAX_DURATION  # Range is optimal
@@ -335,25 +344,35 @@ class _AutoSize:
            return
 
        else:  # Too high latency
-           size =
-           size = max(size, 1)
+           size = self._new_scale()
            _LOGGER.debug('Reducing batch size to %d', size)
 
        self.size = size
        self.duration = 0.0
 
+    def _new_scale(self) -> int:
+        factor = 2 * self.MIN_DURATION / self.duration
+        factor = min(factor, 32)
+        size = int(self.size * factor)
+        return max(size, 1)
+
 
 # ---------------------- map iterable through function ----------------------
 
 
 def _schedule[F: Future](
-
+    submit_chunk: Callable[..., F],
+    args_zip: Iterable[Iterable],
+    chunksize: int,
 ) -> Iterator[F]:
-
+    for chunk in chunked(args_zip, chunksize):
+        f = submit_chunk(*chunk)
+        _LOGGER.debug('Submit %d', len(chunk))
+        yield f
 
 
 def _schedule_auto[F: Future](
-
+    submit_chunk: Callable[..., F],
     args_zip: Iterator[Iterable],
     max_workers: int,
 ) -> Iterator[F]:
@@ -361,28 +380,30 @@ def _schedule_auto[F: Future](
     size = _AutoSize()
     while tuples := [*islice(args_zip, size.suggest() * max_workers)]:
         chunksize = len(tuples) // max_workers or 1
-        for
-            f =
-
+        for chunk in chunked(tuples, chunksize):
+            f = submit_chunk(*chunk)
+            _LOGGER.debug('Submit %d', len(chunk))
+            f.add_done_callback(partial(size.update, len(chunk), monotonic()))
             yield f
 
 
 def _schedule_auto_v2[F: Future](
-
+    submit_chunk: Callable[..., F], args_zip: Iterator[Iterable]
 ) -> Iterator[F]:
     # Vary job size from future to future
     size = _AutoSize()
-    while
-        f =
-
+    while chunk := [*islice(args_zip, size.suggest())]:
+        f = submit_chunk(*chunk)
+        _LOGGER.debug('Submit %d', len(chunk))
+        f.add_done_callback(partial(size.update, len(chunk), monotonic()))
        yield f
 
 
 def _get_unwrap_iter[T](
     s: ExitStack,
     qsize: int,
-
-
+    get_f: Get[Future[T]],
+    sched_iter: Iterator,
 ) -> Iterator[T]:
     with s:
         if not qsize:  # No tasks to do
@@ -390,45 +411,47 @@ def _get_unwrap_iter[T](
 
         # Unwrap 1st / schedule `N-qsize` / unwrap `qsize-1`
         with hide_frame:
-            for _ in chain([None],
+            for _ in chain([None], sched_iter, range(qsize - 1)):
                 # Retrieve done task, exactly `N` calls
-                obj = _result(
+                obj = _result(get_f())
                 if not isinstance(obj, Some):
                     with hide_frame:
                         raise obj
                 yield obj.x
 
 
-def
-    s: ExitStack,
+def _enqueue[T](
     fs: Iterable[Future[T]],
-
-
-    order: bool,
-) -> Iterator[T]:
+    unordered: bool,
+) -> tuple[Iterator, Get[Future[T]]]:
     q = SimpleQueue[Future[T]]()
 
-    #
-    #
-    # FIXME:
+    # In `unordered` mode `q` contains only "DONE" tasks,
+    # else there are also "PENDING" and "RUNNING" tasks.
+    # FIXME: unordered=True -> random freezes (in q.get -> Empty)
     q_put = cast(
         'Callable[[Future[T]], None]',
-        q.put if
+        methodcaller('add_done_callback', q.put) if unordered else q.put,
     )
+    q_get = _q_get_fn(q)
 
     # On each `next()` schedules new task
-
-
-
-
+    sched_iter = map(q_put, fs)
+
+    return sched_iter, q_get
+
 
+def _prefetch(s: ExitStack, sched_iter: Iterator, count: int | None) -> int:
+    try:
+        # Fetch up to `count` tasks to pre-fill `q`
+        qsize = ilen(islice(sched_iter, count))
     except BaseException:
         # Unwind stack here on an error
         s.close()
         raise
-
     else:
-
+        _LOGGER.debug('Prefetched %d jobs', qsize)
+        return qsize
 
 
 def _batch_invoke[*Ts, R](
@@ -446,7 +469,7 @@ def starmap_n[T](
     prefetch: int | None = 2,
     mp: bool = False,
     chunksize: int | None = None,
-
+    unordered: bool = False,
 ) -> Iterator[T]:
     """Equivalent to itertools.starmap(fn, iterable).
 
@@ -455,29 +478,27 @@ def starmap_n[T](
 
     Options:
     - workers - Count of workers, by default all hardware threads are occupied.
-    - prefetch -
+    - prefetch - Count of extra jobs to schedule over N workers.
+      Helps with CPU stalls in ordered mode.
+      Increase if job execution time is highly variable.
     - mp - Whether use processes or threads.
    - chunksize - The size of the chunks the iterable will be broken into
-      before being passed to a processes.
+      before being passed to a processes.
+      Estimated automatically.
      Ignored when threads are used.
-    -
+    - unordered - Retrieve results in order of completion or in original order.
+      In this mode `prefetch` is meaningless, because when some job became done
+      it yielded immediately releasing buffer for new job to schedule.
+      So no CPU stalls.
 
     Unlike multiprocessing.Pool or concurrent.futures.Executor this one:
     - never deadlocks on any exception or Ctrl-C interruption.
-    - accepts infinite iterables due to lazy task creation
+    - accepts infinite iterables due to lazy task creation.
     - has single interface for both threads and processes.
    - TODO: serializes array-like data using out-of-band Pickle 5 buffers.
-    -
-
-
-    Notes:
-    - To reduce latency set order to False, order of results will be arbitrary.
-    - To increase CPU usage increase prefetch or set it to None.
-    - In terms of CPU usage there's no difference between
-      prefetch=None and order=False, so choose wisely.
-    - Setting order to False makes no use of prefetch more than 0.
-
-    TODO: replace `order=True` with `heap=False`
+    - call immediately creates pool ready to yield results
+      (which could take some time cause of serialization for multiprocessing),
+      so first `__next__` runs on warmed up pool.
     """
     if max_workers is None:
         max_workers = max_cpu_count(_NUM_CPUS, mp=mp)
@@ -489,7 +510,9 @@ def starmap_n[T](
        msg = 'With multiprocessing either chunksize or prefetch should be set'
        raise ValueError(msg)
 
-    if prefetch is not None:
+    if unordered:
+        prefetch = max(max_workers, 1)
+    elif prefetch is not None:
        prefetch = max(prefetch + max_workers, 1)
 
    it = iter(iterable)
@@ -502,24 +525,28 @@ def starmap_n[T](
        chunksize = chunksize or 1
 
    if chunksize == 1:
-
-        f1s = starmap(
-
+        submit_1 = cast('Callable[..., Future[T]]', partial(submit, func))
+        f1s = starmap(submit_1, it)
+        sched1_iter, get_f = _enqueue(f1s, unordered)
+        qsize = _prefetch(s, sched1_iter, prefetch)
+        return _get_unwrap_iter(s, qsize, get_f, sched1_iter)
 
-
+    submit_n = cast(
        'Callable[..., Future[list[T]]]', partial(submit, _batch_invoke, func)
    )
    if chunksize is not None:
        # Fixed chunksize
-        fs = _schedule(
+        fs = _schedule(submit_n, it, chunksize)
    elif not _GRANULAR_SCHEDULING:
        # Dynamic chunksize scaling, submit tasks in waves
-        fs = _schedule_auto(
+        fs = _schedule_auto(submit_n, it, max_workers)
    else:
        # Dynamic chunksize scaling
-        fs = _schedule_auto_v2(
+        fs = _schedule_auto_v2(submit_n, it)
 
-
+    sched_iter, get_fs = _enqueue(fs, unordered)
+    qsize = _prefetch(s, sched_iter, prefetch)
+    chunks = _get_unwrap_iter(s, qsize, get_fs, sched_iter)
    return chain.from_iterable(chunks)
 
 
@@ -531,7 +558,7 @@ def map_n[T](
     prefetch: int | None = 2,
     mp: bool = False,
     chunksize: int | None = None,
-
+    unordered: bool = False,
 ) -> Iterator[T]:
     """Return iterator equivalent to map(func, *iterables).
 
@@ -547,7 +574,7 @@ def map_n[T](
         prefetch=prefetch,
         mp=mp,
         chunksize=chunksize,
-
+        unordered=unordered,
     )
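The `order` parameter and its docstring notes are gone; `starmap_n`/`map_n` now take a single `unordered` flag. A minimal sketch of the new option, assuming the top-level `glow.map_n` re-export and the signature shown in the diff; the worker function and sizes are illustrative:

import time

from glow import map_n

def work(x: int) -> int:
    time.sleep(0.01 * (x % 3))  # variable per-job duration
    return x * x

# Default: results come back in input order.
assert list(map_n(work, range(10))) == [x * x for x in range(10)]

# unordered=True: results are yielded as jobs complete, so slow jobs
# no longer delay fast ones; `prefetch` is effectively ignored here.
assert sorted(map_n(work, range(10), unordered=True)) == [
    x * x for x in range(10)
]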
{glow-0.15.6 → glow-0.15.8}/src/glow/_reduction.py

@@ -2,7 +2,6 @@ __all__ = ['move_to_shmem']
 
 import copyreg
 import io
-import logging
 import mmap
 import os
 import pickle
@@ -25,10 +24,6 @@ _SYSTEM_TEMP = Path(tempfile.gettempdir())
 reducers: dict[type, Callable] = {}
 loky.set_loky_pickler('pickle')
 
-logger = logging.getLogger(__name__)
-# logger.setLevel(logging.DEBUG)
-# logger.addHandler(logging.StreamHandler())
-
 
 def _get_shm_dir() -> Path:
     if sys.platform != 'win32':
{glow-0.15.6 → glow-0.15.8}/src/glow/_thread_quota.py

@@ -8,6 +8,7 @@
 __all__ = ['ThreadQuota']
 
 import os
+import sys
 from collections import deque
 from collections.abc import Callable
 from concurrent.futures import Executor, Future
@@ -18,6 +19,13 @@ from threading import _register_atexit  # type: ignore[attr-defined]
 from threading import Lock, Thread
 from weakref import WeakSet
 
+if sys.version_info >= (3, 14):
+    from concurrent.futures.thread import WorkerContext
+
+    _worker_ctx = WorkerContext(lambda: None, ())
+else:
+    _worker_ctx = None
+
 # TODO: investigate hangups when _TIMEOUT <= .01
 _TIMEOUT = 1
 _MIN_IDLE = os.cpu_count() or 1
@@ -64,7 +72,10 @@ def _worker(q: _Pipe) -> None:
     try:
         while executor := _safe_call(q.get, timeout=_TIMEOUT):
             while work_item := _safe_call(executor._work_queue.popleft):
-                work_item.run()
+                if sys.version_info >= (3, 14):
+                    work_item.run(_worker_ctx)  # Process task
+                else:
+                    work_item.run()
             if _shutdown:
                 executor._shutdown = True
                 return
@@ -114,7 +125,10 @@ class ThreadQuota(Executor):
            msg = 'cannot schedule futures after shutdown'
            raise RuntimeError(msg)
 
-        self._work_queue.append(_WorkItem(f, fn, args, kwargs))
+        if sys.version_info >= (3, 14):
+            self._work_queue.append(_WorkItem(f, (fn, args, kwargs)))
+        else:
+            self._work_queue.append(_WorkItem(f, fn, args, kwargs))
 
        if _safe_call(self._idle.pop):  # Pool is not maximized yet
            if q := _safe_call(_idle.pop):  # Use idle worker
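These `_thread_quota` changes are a Python 3.14 compatibility shim: there `concurrent.futures.thread._WorkItem` takes the call as one `(fn, args, kwargs)` tuple and `run()` expects a `WorkerContext`. A standalone sketch of the same version-gating pattern, assuming only the private-API signatures visible in the diff above:

import sys
from concurrent.futures import Future
from concurrent.futures.thread import _WorkItem

def make_work_item(f: Future, fn, args, kwargs) -> _WorkItem:
    # Build a work item against either stdlib layout.
    if sys.version_info >= (3, 14):
        return _WorkItem(f, (fn, args, kwargs))
    return _WorkItem(f, fn, args, kwargs)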
{glow-0.15.6 → glow-0.15.8}/src/glow/cli.py

@@ -48,7 +48,7 @@ import sys
 import types
 from argparse import ArgumentParser, BooleanOptionalAction, _ArgumentGroup
 from collections.abc import Callable, Iterable, Iterator, Sequence
-from dataclasses import MISSING, Field, field, fields, is_dataclass
+from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass
 from inspect import getmodule, signature, stack
 from typing import (
     Any,
@@ -61,9 +61,26 @@ from typing import (
     get_type_hints,
 )
 
+from typing_inspection.introspection import (
+    UNKNOWN,
+    AnnotationSource,
+    inspect_annotation,
+)
+
 type _Node = str | tuple[str, type, list['_Node']]
 
 
+@dataclass(kw_only=True)
+class Meta:
+    help: str = ''
+    flag: str | None = None
+
+
+@dataclass(kw_only=True)
+class _Meta(Meta):
+    name: str
+
+
 def arg(
     default=MISSING,
     /,
@@ -156,19 +173,36 @@ def _get_fields(fn: Callable) -> Iterator[Field]:
         yield fd
 
 
+def _get_metadata(tp: type, fd: Field) -> tuple[type, _Meta]:
+    info = inspect_annotation(tp, annotation_source=AnnotationSource.CLASS)
+
+    flag = fd.metadata.get('flag')
+    name = fd.name.replace('_', '-')
+    help_ = fd.metadata.get('help') or ''
+
+    if info.type is not UNKNOWN:
+        tp = info.type
+    for m in info.metadata:
+        if isinstance(m, Meta):
+            help_ = m.help
+            flag = m.flag
+
+    return tp, _Meta(help=help_, flag=flag, name=name)
+
+
 def _visit_nested(
     parser: ArgumentParser | _ArgumentGroup,
     fn: Callable,
     seen: dict[str, list],
 ) -> list[_Node]:
     try:
-        hints = get_type_hints(fn)
+        hints = get_type_hints(fn, include_extras=True)
     except NameError:
         if fn.__module__ != '__main__':
             raise
         for finfo in stack():
-            if not getmodule(finfo.frame):
-                hints = get_type_hints(fn,
+            if not getmodule(f := finfo.frame):
+                hints = get_type_hints(fn, f.f_globals, include_extras=True)
                 break
         else:
             raise
@@ -196,11 +230,11 @@ def _visit_field(
     fd: Field,
     seen: dict[str, list],
 ) -> _Node:
+    tp, meta = _get_metadata(tp, fd)
     cls, opts = _unwrap_type(tp)
 
-    help_ = fd.metadata.get('help') or ''
     if cls is not bool and fd.default is not MISSING:
-        help_ += f' (default: {fd.default})'
+        meta.help += f' (default: {fd.default})'
 
     if is_dataclass(cls):  # Nested dataclass
         arg_group = parser.add_argument_group(fd.name)
@@ -218,9 +252,7 @@ def _visit_field(
         )
         raise TypeError(msg)
 
-
-    flags = [f] if (f := fd.metadata.get('flag')) else []
-
+    flags = [meta.flag] if meta.flag else []
     default = (
         fd.default if fd.default_factory is MISSING else fd.default_factory()
     )
@@ -230,11 +262,11 @@ def _visit_field(
            msg = f'Boolean field "{fd.name}" should have default'
            raise ValueError(msg)
        parser.add_argument(
-            f'--{
+            f'--{meta.name}',
            *flags,
            action=BooleanOptionalAction,
            default=default,
-            help=
+            help=meta.help,
        )
 
    # Generic optional
@@ -242,14 +274,14 @@ def _visit_field(
        if opts.get('nargs') == argparse.OPTIONAL:
            del opts['nargs']
        parser.add_argument(
-            f'--{
+            f'--{meta.name}', *flags, **opts, default=default, help=meta.help
        )
 
    elif isinstance(parser, ArgumentParser):  # Allow only for root parser
-        if
+        if meta.flag:
            msg = f'Positional-only field "{fd.name}" should not have flag'
            raise ValueError(msg)
-        parser.add_argument(
+        parser.add_argument(meta.name, **opts, help=meta.help)
 
    else:
        msg = (
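`glow.cli` now also reads argument metadata from `Annotated[...]` markers via `typing-inspection`, through the new public `Meta(help=..., flag=...)` dataclass. A hypothetical sketch mirroring the test cases added below; the `Opts` class and argv are illustrative, and `parse_args(cls, argv)` returning an `Opts` instance is an assumption suggested by the parametrized tests:

from dataclasses import dataclass
from typing import Annotated

from glow.cli import Meta, parse_args

@dataclass
class Opts:
    src: Annotated[str, Meta(help='input file')]  # positional `src`
    jobs: Annotated[int, Meta(help='worker count', flag='-j')] = 4  # --jobs / -j

opts = parse_args(Opts, ['in.txt', '-j', '8'])
# expected: Opts(src='in.txt', jobs=8)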
{glow-0.15.6 → glow-0.15.8}/src/glow/cli.pyi

@@ -1,9 +1,15 @@
 from argparse import ArgumentParser
 from collections.abc import Callable, Mapping, Sequence
+from dataclasses import dataclass
 from typing import Any, overload
 
 from ._types import Get
 
+@dataclass
+class Meta:
+    help: str = ...
+    flag: str | None = ...
+
 @overload
 def arg[T](
     default: T,
{glow-0.15.6 → glow-0.15.8}/test/test_cli.py

@@ -1,11 +1,11 @@
 from collections.abc import Callable
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Literal
+from typing import Annotated, Any, Literal
 
 import pytest
 
-from glow.cli import parse_args
+from glow.cli import Meta, parse_args
 
 
 @dataclass
@@ -81,6 +81,21 @@ class Custom:
     arg: Path
 
 
+@dataclass
+class AnnotatedPositional:
+    arg: Annotated[int, Meta(help='help')]
+
+
+@dataclass
+class FlagKeyword:
+    param: Annotated[int, Meta(help='help', flag='-p')] = 42
+
+
+@dataclass
+class AnnotatedKeyword:
+    param: Annotated[int, Meta(help='help')] = 42
+
+
 @pytest.mark.parametrize(
     ('argv', 'expected'),
     [
@@ -98,6 +113,10 @@ class Custom:
         (['value'], Nested('value', Optional_())),
         (['value', '--param', 'pvalue'], Nested('value', Optional_('pvalue'))),
         (['test.txt'], Custom(Path('test.txt'))),
+        (['5'], AnnotatedPositional(5)),
+        (['--param', '5'], AnnotatedKeyword(5)),
+        ([], AnnotatedKeyword(42)),
+        (['-p', '42'], FlagKeyword(42)),
     ],
 )
 def test_good_class(argv: list[str], expected: Any):