wandb 0.19.6rc4__py3-none-win_amd64.whl → 0.19.8__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +56 -6
- wandb/apis/public/_generated/__init__.py +21 -0
- wandb/apis/public/_generated/base.py +128 -0
- wandb/apis/public/_generated/enums.py +4 -0
- wandb/apis/public/_generated/input_types.py +4 -0
- wandb/apis/public/_generated/operations.py +15 -0
- wandb/apis/public/_generated/server_features_query.py +27 -0
- wandb/apis/public/_generated/typing_compat.py +14 -0
- wandb/apis/public/api.py +192 -6
- wandb/apis/public/artifacts.py +13 -45
- wandb/apis/public/registries.py +573 -0
- wandb/apis/public/utils.py +36 -0
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +11 -20
- wandb/data_types.py +1 -1
- wandb/env.py +10 -0
- wandb/filesync/dir_watcher.py +2 -1
- wandb/proto/v3/wandb_internal_pb2.py +243 -222
- wandb/proto/v3/wandb_server_pb2.py +4 -4
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v4/wandb_internal_pb2.py +226 -222
- wandb/proto/v4/wandb_server_pb2.py +4 -4
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +10 -10
- wandb/proto/v5/wandb_internal_pb2.py +226 -222
- wandb/proto/v5/wandb_server_pb2.py +4 -4
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +10 -10
- wandb/sdk/artifacts/_graphql_fragments.py +126 -0
- wandb/sdk/artifacts/artifact.py +51 -95
- wandb/sdk/backend/backend.py +17 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
- wandb/sdk/data_types/helper_types/image_mask.py +12 -6
- wandb/sdk/data_types/saved_model.py +35 -46
- wandb/sdk/data_types/video.py +7 -16
- wandb/sdk/interface/interface.py +87 -49
- wandb/sdk/interface/interface_queue.py +5 -15
- wandb/sdk/interface/interface_relay.py +7 -22
- wandb/sdk/interface/interface_shared.py +65 -136
- wandb/sdk/interface/interface_sock.py +3 -21
- wandb/sdk/interface/router.py +42 -68
- wandb/sdk/interface/router_queue.py +13 -11
- wandb/sdk/interface/router_relay.py +26 -13
- wandb/sdk/interface/router_sock.py +12 -16
- wandb/sdk/internal/handler.py +4 -3
- wandb/sdk/internal/internal_api.py +12 -1
- wandb/sdk/internal/sender.py +3 -19
- wandb/sdk/lib/apikey.py +87 -26
- wandb/sdk/lib/asyncio_compat.py +210 -0
- wandb/sdk/lib/console_capture.py +172 -0
- wandb/sdk/lib/progress.py +78 -16
- wandb/sdk/lib/redirect.py +102 -76
- wandb/sdk/lib/service_connection.py +37 -17
- wandb/sdk/lib/sock_client.py +6 -56
- wandb/sdk/mailbox/__init__.py +23 -0
- wandb/sdk/mailbox/mailbox.py +135 -0
- wandb/sdk/mailbox/mailbox_handle.py +127 -0
- wandb/sdk/mailbox/response_handle.py +167 -0
- wandb/sdk/mailbox/wait_with_progress.py +135 -0
- wandb/sdk/service/server_sock.py +9 -3
- wandb/sdk/service/streams.py +75 -78
- wandb/sdk/verify/verify.py +54 -2
- wandb/sdk/wandb_init.py +72 -75
- wandb/sdk/wandb_login.py +7 -4
- wandb/sdk/wandb_metadata.py +65 -34
- wandb/sdk/wandb_require.py +14 -8
- wandb/sdk/wandb_run.py +90 -97
- wandb/sdk/wandb_settings.py +10 -4
- wandb/sdk/wandb_setup.py +19 -8
- wandb/sdk/wandb_sync.py +2 -10
- wandb/util.py +3 -1
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/METADATA +2 -2
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/RECORD +79 -66
- wandb/sdk/interface/message_future.py +0 -27
- wandb/sdk/interface/message_future_poll.py +0 -50
- wandb/sdk/lib/mailbox.py +0 -442
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/WHEEL +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.8.dist-info}/licenses/LICENSE +0 -0
@@ -4,14 +4,14 @@ Router to manage responses from a socket client.
|
|
4
4
|
|
5
5
|
"""
|
6
6
|
|
7
|
-
from
|
7
|
+
from __future__ import annotations
|
8
8
|
|
9
|
-
from
|
10
|
-
from
|
11
|
-
from .
|
9
|
+
from wandb.proto import wandb_internal_pb2 as pb
|
10
|
+
from wandb.proto import wandb_server_pb2 as spb
|
11
|
+
from wandb.sdk.lib.sock_client import SockClient, SockClientClosedError
|
12
|
+
from wandb.sdk.mailbox import Mailbox
|
12
13
|
|
13
|
-
|
14
|
-
from wandb.proto import wandb_internal_pb2 as pb
|
14
|
+
from .router import MessageRouter, MessageRouterClosedError
|
15
15
|
|
16
16
|
|
17
17
|
class MessageSockRouter(MessageRouter):
|
@@ -22,15 +22,11 @@ class MessageSockRouter(MessageRouter):
|
|
22
22
|
self._sock_client = sock_client
|
23
23
|
super().__init__(mailbox=mailbox)
|
24
24
|
|
25
|
-
def _read_message(self) ->
|
25
|
+
def _read_message(self) -> spb.ServerResponse | None:
|
26
26
|
try:
|
27
|
-
|
28
|
-
except SockClientClosedError:
|
29
|
-
raise MessageRouterClosedError
|
30
|
-
|
31
|
-
|
32
|
-
msg = resp.result_communicate
|
33
|
-
return msg
|
34
|
-
|
35
|
-
def _send_message(self, record: "pb.Record") -> None:
|
27
|
+
return self._sock_client.read_server_response(timeout=1)
|
28
|
+
except SockClientClosedError as e:
|
29
|
+
raise MessageRouterClosedError from e
|
30
|
+
|
31
|
+
def _send_message(self, record: pb.Record) -> None:
|
36
32
|
self._sock_client.send_record_communicate(record)
|
wandb/sdk/internal/handler.py
CHANGED
@@ -205,9 +205,6 @@ class HandleManager:
|
|
205
205
|
# defer is used to drive the sender finish state machine
|
206
206
|
self._dispatch_record(record, always_send=True)
|
207
207
|
|
208
|
-
def handle_request_login(self, record: Record) -> None:
|
209
|
-
self._dispatch_record(record)
|
210
|
-
|
211
208
|
def handle_request_python_packages(self, record: Record) -> None:
|
212
209
|
self._dispatch_record(record)
|
213
210
|
|
@@ -892,6 +889,10 @@ class HandleManager:
|
|
892
889
|
self._respond_result(result)
|
893
890
|
self._stopped.set()
|
894
891
|
|
892
|
+
def handle_request_operations(self, record: Record) -> None:
|
893
|
+
"""No-op. Not implemented for the legacy-service."""
|
894
|
+
self._respond_result(proto_util._result_from_record(record))
|
895
|
+
|
895
896
|
def finish(self) -> None:
|
896
897
|
logger.info("shutting down handler")
|
897
898
|
if self._system_monitor is not None:
|
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
|
|
115
115
|
root_dir: Optional[str]
|
116
116
|
api_key: Optional[str]
|
117
117
|
entity: Optional[str]
|
118
|
+
organization: Optional[str]
|
118
119
|
project: Optional[str]
|
119
120
|
_extra_http_headers: Optional[Mapping[str, str]]
|
120
121
|
_proxies: Optional[Mapping[str, str]]
|
@@ -256,6 +257,7 @@ class Api:
|
|
256
257
|
"root_dir": None,
|
257
258
|
"api_key": None,
|
258
259
|
"entity": None,
|
260
|
+
"organization": None,
|
259
261
|
"project": None,
|
260
262
|
"_extra_http_headers": None,
|
261
263
|
"_proxies": None,
|
@@ -489,7 +491,8 @@ class Api:
|
|
489
491
|
{
|
490
492
|
"entity": "models",
|
491
493
|
"base_url": "https://api.wandb.ai",
|
492
|
-
"project": None
|
494
|
+
"project": None,
|
495
|
+
"organization": "my-org",
|
493
496
|
}
|
494
497
|
"""
|
495
498
|
result = self.default_settings.copy()
|
@@ -504,6 +507,14 @@ class Api:
|
|
504
507
|
),
|
505
508
|
env=self._environ,
|
506
509
|
),
|
510
|
+
"organization": env.get_organization(
|
511
|
+
self._settings.get(
|
512
|
+
Settings.DEFAULT_SECTION,
|
513
|
+
"organization",
|
514
|
+
fallback=result.get("organization"),
|
515
|
+
),
|
516
|
+
env=self._environ,
|
517
|
+
),
|
507
518
|
"project": env.get_project(
|
508
519
|
self._settings.get(
|
509
520
|
Settings.DEFAULT_SECTION,
|
wandb/sdk/internal/sender.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""sender."""
|
2
2
|
|
3
3
|
import contextlib
|
4
|
+
import glob
|
4
5
|
import gzip
|
5
6
|
import json
|
6
7
|
import logging
|
@@ -544,23 +545,6 @@ class SendManager:
|
|
544
545
|
logger.warning(f"Error emptying retry queue: {e}")
|
545
546
|
self._respond_result(result)
|
546
547
|
|
547
|
-
def send_request_login(self, record: "Record") -> None:
|
548
|
-
# TODO: do something with api_key or anonymous?
|
549
|
-
# TODO: return an error if we aren't logged in?
|
550
|
-
self._api.reauth()
|
551
|
-
viewer = self.get_viewer_info()
|
552
|
-
server_info = self.get_server_info()
|
553
|
-
# self._login_flags = json.loads(viewer.get("flags", "{}"))
|
554
|
-
# self._login_entity = viewer.get("entity")
|
555
|
-
if server_info:
|
556
|
-
logger.info(f"Login server info: {server_info}")
|
557
|
-
self._entity = viewer.get("entity")
|
558
|
-
if record.control.req_resp:
|
559
|
-
result = proto_util._result_from_record(record)
|
560
|
-
if self._entity:
|
561
|
-
result.response.login_response.active_entity = self._entity
|
562
|
-
self._respond_result(result)
|
563
|
-
|
564
548
|
def send_exit(self, record: "Record") -> None:
|
565
549
|
# track where the exit came from
|
566
550
|
self._record_exit = record
|
@@ -1425,7 +1409,8 @@ class SendManager:
|
|
1425
1409
|
for k in files.files:
|
1426
1410
|
# TODO(jhr): fix paths with directories
|
1427
1411
|
self._save_file(
|
1428
|
-
interface.GlobStr(k.path),
|
1412
|
+
interface.GlobStr(glob.escape(k.path)),
|
1413
|
+
interface.file_enum_to_policy(k.policy),
|
1429
1414
|
)
|
1430
1415
|
|
1431
1416
|
def send_header(self, record: "Record") -> None:
|
@@ -1491,7 +1476,6 @@ class SendManager:
|
|
1491
1476
|
self._job_builder.set_partial_source_id(use.id)
|
1492
1477
|
|
1493
1478
|
def send_request_log_artifact(self, record: "Record") -> None:
|
1494
|
-
assert record.control.req_resp
|
1495
1479
|
result = proto_util._result_from_record(record)
|
1496
1480
|
artifact = record.request.log_artifact.artifact
|
1497
1481
|
history_step = record.request.log_artifact.history_step
|
wandb/sdk/lib/apikey.py
CHANGED
@@ -1,5 +1,8 @@
|
|
1
1
|
"""apikey util."""
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import dataclasses
|
3
6
|
import os
|
4
7
|
import platform
|
5
8
|
import stat
|
@@ -32,8 +35,21 @@ LOGIN_CHOICES = [
|
|
32
35
|
LOGIN_CHOICE_DRYRUN,
|
33
36
|
]
|
34
37
|
|
38
|
+
|
39
|
+
@dataclasses.dataclass(frozen=True)
|
40
|
+
class _NetrcPermissions:
|
41
|
+
exists: bool
|
42
|
+
read_access: bool
|
43
|
+
write_access: bool
|
44
|
+
|
45
|
+
|
46
|
+
class WriteNetrcError(Exception):
|
47
|
+
"""Raised when we cannot write to the netrc file."""
|
48
|
+
|
49
|
+
|
35
50
|
Mode = Literal["allow", "must", "never", "false", "true"]
|
36
51
|
|
52
|
+
|
37
53
|
if TYPE_CHECKING:
|
38
54
|
from wandb.sdk.wandb_settings import Settings
|
39
55
|
|
@@ -170,29 +186,64 @@ def prompt_api_key( # noqa: C901
|
|
170
186
|
return key
|
171
187
|
|
172
188
|
|
173
|
-
def
|
189
|
+
def check_netrc_access(
|
190
|
+
netrc_path: str,
|
191
|
+
) -> _NetrcPermissions:
|
192
|
+
"""Check if we can read and write to the netrc file."""
|
193
|
+
file_exists = False
|
194
|
+
write_access = False
|
195
|
+
read_access = False
|
196
|
+
try:
|
197
|
+
st = os.stat(netrc_path)
|
198
|
+
file_exists = True
|
199
|
+
write_access = bool(st.st_mode & stat.S_IWUSR)
|
200
|
+
read_access = bool(st.st_mode & stat.S_IRUSR)
|
201
|
+
except FileNotFoundError:
|
202
|
+
# If the netrc file doesn't exist, we will create it.
|
203
|
+
write_access = True
|
204
|
+
read_access = True
|
205
|
+
except OSError as e:
|
206
|
+
wandb.termerror(f"Unable to read permissions for {netrc_path}, {e}")
|
207
|
+
|
208
|
+
return _NetrcPermissions(
|
209
|
+
exists=file_exists,
|
210
|
+
write_access=write_access,
|
211
|
+
read_access=read_access,
|
212
|
+
)
|
213
|
+
|
214
|
+
|
215
|
+
def write_netrc(host: str, entity: str, key: str):
|
174
216
|
"""Add our host and key to .netrc."""
|
175
217
|
_, key_suffix = key.split("-", 1) if "-" in key else ("", key)
|
176
218
|
if len(key_suffix) != 40:
|
177
|
-
|
219
|
+
raise ValueError(
|
178
220
|
"API-key must be exactly 40 characters long: {} ({} chars)".format(
|
179
221
|
key_suffix, len(key_suffix)
|
180
222
|
)
|
181
223
|
)
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
224
|
+
|
225
|
+
normalized_host = urlparse(host).netloc
|
226
|
+
netrc_path = get_netrc_file_path()
|
227
|
+
netrc_access = check_netrc_access(netrc_path)
|
228
|
+
|
229
|
+
if not netrc_access.write_access or not netrc_access.read_access:
|
230
|
+
raise WriteNetrcError(
|
231
|
+
f"Cannot access {netrc_path}. In order to persist your API key, "
|
232
|
+
"grant read and write permissions for your user to the file "
|
233
|
+
'or specify a different file with the environment variable "NETRC=<new_netrc_path>".'
|
188
234
|
)
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
235
|
+
|
236
|
+
machine_line = f"machine {normalized_host}"
|
237
|
+
orig_lines = None
|
238
|
+
try:
|
239
|
+
with open(netrc_path) as f:
|
240
|
+
orig_lines = f.read().strip().split("\n")
|
241
|
+
except FileNotFoundError:
|
242
|
+
wandb.termlog("No netrc file found, creating one.")
|
243
|
+
except OSError as e:
|
244
|
+
raise WriteNetrcError(f"Unable to read {netrc_path}") from e
|
245
|
+
|
246
|
+
try:
|
196
247
|
with open(netrc_path, "w") as f:
|
197
248
|
if orig_lines:
|
198
249
|
# delete this machine from the file if it's already there.
|
@@ -206,20 +257,22 @@ def write_netrc(host: str, entity: str, key: str) -> Optional[bool]:
|
|
206
257
|
skip -= 1
|
207
258
|
else:
|
208
259
|
f.write("{}\n".format(line))
|
260
|
+
|
261
|
+
wandb.termlog(
|
262
|
+
f"Appending key for {normalized_host} to your netrc file: {netrc_path}"
|
263
|
+
)
|
209
264
|
f.write(
|
210
265
|
textwrap.dedent(
|
211
266
|
"""\
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
267
|
+
machine {host}
|
268
|
+
login {entity}
|
269
|
+
password {key}
|
270
|
+
"""
|
216
271
|
).format(host=normalized_host, entity=entity, key=key)
|
217
272
|
)
|
218
273
|
os.chmod(netrc_path, stat.S_IRUSR | stat.S_IWUSR)
|
219
|
-
|
220
|
-
|
221
|
-
wandb.termerror(f"Unable to read {netrc_path}")
|
222
|
-
return None
|
274
|
+
except OSError as e:
|
275
|
+
raise WriteNetrcError(f"Unable to write {netrc_path}") from e
|
223
276
|
|
224
277
|
|
225
278
|
def write_key(
|
@@ -250,7 +303,15 @@ def api_key(settings: Optional["Settings"] = None) -> Optional[str]:
|
|
250
303
|
settings = wandb.setup().settings
|
251
304
|
if settings.api_key:
|
252
305
|
return settings.api_key
|
253
|
-
|
254
|
-
|
255
|
-
|
306
|
+
|
307
|
+
netrc_access = check_netrc_access(get_netrc_file_path())
|
308
|
+
if netrc_access.exists and not netrc_access.read_access:
|
309
|
+
wandb.termwarn(f"Cannot access {get_netrc_file_path()}.")
|
310
|
+
return None
|
311
|
+
|
312
|
+
if netrc_access.exists:
|
313
|
+
auth = get_netrc_auth(settings.base_url)
|
314
|
+
if auth:
|
315
|
+
return auth[-1]
|
316
|
+
|
256
317
|
return None
|
@@ -0,0 +1,210 @@
|
|
1
|
+
"""Functions for compatibility with asyncio."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import concurrent
|
7
|
+
import concurrent.futures
|
8
|
+
import contextlib
|
9
|
+
import threading
|
10
|
+
from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, TypeVar
|
11
|
+
|
12
|
+
_T = TypeVar("_T")
|
13
|
+
|
14
|
+
|
15
|
+
def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
|
16
|
+
"""Run `fn` in an asyncio loop in a new thread.
|
17
|
+
|
18
|
+
This must always be used instead of `asyncio.run` which fails if there is
|
19
|
+
an active `asyncio` event loop in the current thread. Since `wandb` was not
|
20
|
+
originally designed with `asyncio` in mind, using `asyncio.run` would break
|
21
|
+
users who were calling `wandb` methods from an `asyncio` loop.
|
22
|
+
|
23
|
+
Note that due to starting a new thread, this is slightly slow.
|
24
|
+
"""
|
25
|
+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
26
|
+
runner = _Runner()
|
27
|
+
future = executor.submit(runner.run, fn)
|
28
|
+
|
29
|
+
try:
|
30
|
+
return future.result()
|
31
|
+
|
32
|
+
finally:
|
33
|
+
runner.cancel()
|
34
|
+
|
35
|
+
|
36
|
+
class _RunnerCancelledError(Exception):
|
37
|
+
"""The `_Runner.run()` invocation was cancelled."""
|
38
|
+
|
39
|
+
|
40
|
+
class _Runner:
|
41
|
+
"""Runs an asyncio event loop allowing cancellation.
|
42
|
+
|
43
|
+
This is like `asyncio.run()`, except it provides a `cancel()` method
|
44
|
+
meant to be called in a `finally` block.
|
45
|
+
|
46
|
+
Without this, it is impossible to make `asyncio.run()` stop if it runs
|
47
|
+
in a non-main thread. In particular, a KeyboardInterrupt causes the
|
48
|
+
ThreadPoolExecutor above to block until the asyncio thread completes,
|
49
|
+
but there is no way to tell the asyncio thread to cancel its work.
|
50
|
+
A second KeyboardInterrupt makes ThreadPoolExecutor give up while the
|
51
|
+
asyncio thread still runs in the background, with terrible effects if it
|
52
|
+
prints to the user's terminal.
|
53
|
+
"""
|
54
|
+
|
55
|
+
def __init__(self) -> None:
|
56
|
+
self._lock = threading.Condition()
|
57
|
+
|
58
|
+
self._is_cancelled = False
|
59
|
+
self._started = False
|
60
|
+
self._done = False
|
61
|
+
|
62
|
+
self._loop: asyncio.AbstractEventLoop | None = None
|
63
|
+
self._cancel_event: asyncio.Event | None = None
|
64
|
+
|
65
|
+
def run(self, fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
|
66
|
+
"""Run a coroutine in asyncio, cancelling it on `cancel()`.
|
67
|
+
|
68
|
+
Returns:
|
69
|
+
The result of the coroutine returned by `fn`.
|
70
|
+
|
71
|
+
Raises:
|
72
|
+
_RunnerCancelledError: If `cancel()` is called.
|
73
|
+
"""
|
74
|
+
return asyncio.run(self._run_or_cancel(fn))
|
75
|
+
|
76
|
+
async def _run_or_cancel(
|
77
|
+
self,
|
78
|
+
fn: Callable[[], Coroutine[Any, Any, _T]],
|
79
|
+
) -> _T:
|
80
|
+
with self._lock:
|
81
|
+
if self._is_cancelled:
|
82
|
+
raise _RunnerCancelledError()
|
83
|
+
|
84
|
+
self._loop = asyncio.get_running_loop()
|
85
|
+
self._cancel_event = asyncio.Event()
|
86
|
+
self._started = True
|
87
|
+
|
88
|
+
cancellation_task = asyncio.create_task(self._cancel_event.wait())
|
89
|
+
fn_task = asyncio.create_task(fn())
|
90
|
+
|
91
|
+
try:
|
92
|
+
await asyncio.wait(
|
93
|
+
[cancellation_task, fn_task],
|
94
|
+
return_when=asyncio.FIRST_COMPLETED,
|
95
|
+
)
|
96
|
+
|
97
|
+
if fn_task.done():
|
98
|
+
return fn_task.result()
|
99
|
+
else:
|
100
|
+
raise _RunnerCancelledError()
|
101
|
+
|
102
|
+
finally:
|
103
|
+
cancellation_task.cancel()
|
104
|
+
fn_task.cancel()
|
105
|
+
|
106
|
+
with self._lock:
|
107
|
+
self._done = True
|
108
|
+
|
109
|
+
def cancel(self) -> None:
|
110
|
+
"""Cancel all asyncio work started by `run()`."""
|
111
|
+
with self._lock:
|
112
|
+
if self._is_cancelled:
|
113
|
+
return
|
114
|
+
self._is_cancelled = True
|
115
|
+
|
116
|
+
if self._done or not self._started:
|
117
|
+
# If the runner already finished, no need to cancel it.
|
118
|
+
#
|
119
|
+
# If the runner hasn't started the loop yet, then it will not
|
120
|
+
# as we already set _is_cancelled.
|
121
|
+
return
|
122
|
+
|
123
|
+
assert self._loop
|
124
|
+
assert self._cancel_event
|
125
|
+
self._loop.call_soon_threadsafe(self._cancel_event.set)
|
126
|
+
|
127
|
+
|
128
|
+
class TaskGroup:
|
129
|
+
"""Object that `open_task_group()` yields."""
|
130
|
+
|
131
|
+
def __init__(self) -> None:
|
132
|
+
self._tasks: list[asyncio.Task] = []
|
133
|
+
|
134
|
+
def start_soon(self, coro: Coroutine[Any, Any, Any]) -> None:
|
135
|
+
"""Schedule a task in the group.
|
136
|
+
|
137
|
+
Args:
|
138
|
+
coro: The return value of the `async` function defining the task.
|
139
|
+
"""
|
140
|
+
self._tasks.append(asyncio.create_task(coro))
|
141
|
+
|
142
|
+
async def _wait_all(self) -> None:
|
143
|
+
"""Block until all tasks complete.
|
144
|
+
|
145
|
+
Raises:
|
146
|
+
Exception: If one or more tasks raises an exception, one of these
|
147
|
+
is raised arbitrarily.
|
148
|
+
"""
|
149
|
+
done, _ = await asyncio.wait(
|
150
|
+
self._tasks,
|
151
|
+
# NOTE: Cancelling a task counts as a normal exit,
|
152
|
+
# not an exception.
|
153
|
+
return_when=concurrent.futures.FIRST_EXCEPTION,
|
154
|
+
)
|
155
|
+
|
156
|
+
for task in done:
|
157
|
+
try:
|
158
|
+
if exc := task.exception():
|
159
|
+
raise exc
|
160
|
+
except asyncio.CancelledError:
|
161
|
+
pass
|
162
|
+
|
163
|
+
def _cancel_all(self) -> None:
|
164
|
+
"""Cancel all tasks."""
|
165
|
+
for task in self._tasks:
|
166
|
+
# NOTE: It is safe to cancel tasks that have already completed.
|
167
|
+
task.cancel()
|
168
|
+
|
169
|
+
|
170
|
+
@contextlib.asynccontextmanager
|
171
|
+
async def open_task_group() -> AsyncIterator[TaskGroup]:
|
172
|
+
"""Create a task group.
|
173
|
+
|
174
|
+
`asyncio` gained task groups in Python 3.11.
|
175
|
+
|
176
|
+
This is an async context manager, meant to be used with `async with`.
|
177
|
+
On exit, it blocks until all subtasks complete. If any subtask fails, or if
|
178
|
+
the current task is cancelled, it cancels all subtasks in the group and
|
179
|
+
raises the subtask's exception. If multiple subtasks fail simultaneously,
|
180
|
+
one of their exceptions is chosen arbitrarily.
|
181
|
+
|
182
|
+
NOTE: Subtask exceptions do not propagate until the context manager exits.
|
183
|
+
This means that the task group cannot cancel code running inside the
|
184
|
+
`async with` block .
|
185
|
+
"""
|
186
|
+
task_group = TaskGroup()
|
187
|
+
|
188
|
+
try:
|
189
|
+
yield task_group
|
190
|
+
await task_group._wait_all()
|
191
|
+
finally:
|
192
|
+
task_group._cancel_all()
|
193
|
+
|
194
|
+
|
195
|
+
@contextlib.contextmanager
|
196
|
+
def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> Iterator[None]:
|
197
|
+
"""Schedule a task, cancelling it when exiting the context manager.
|
198
|
+
|
199
|
+
If the given coroutine raises an exception, that exception is raised
|
200
|
+
when exiting the context manager.
|
201
|
+
"""
|
202
|
+
task = asyncio.create_task(coro)
|
203
|
+
|
204
|
+
try:
|
205
|
+
yield
|
206
|
+
finally:
|
207
|
+
if task.done() and (exception := task.exception()):
|
208
|
+
raise exception
|
209
|
+
|
210
|
+
task.cancel()
|
@@ -0,0 +1,172 @@
|
|
1
|
+
"""Module for intercepting stdout/stderr.
|
2
|
+
|
3
|
+
This patches the `write()` method of `stdout` and `stderr` on import.
|
4
|
+
Once patched, it is not possible to unpatch or repatch, though individual
|
5
|
+
callbacks can be removed.
|
6
|
+
|
7
|
+
We assume that all other writing methods on the object delegate to `write()`,
|
8
|
+
like `writelines()`. This is not guaranteed to be true, but it is true for
|
9
|
+
common implementations. In particular, CPython's implementation of IOBase's
|
10
|
+
`writelines()` delegates to `write()`.
|
11
|
+
|
12
|
+
It is important to note that this technique interacts poorly with other
|
13
|
+
code that performs similar patching if it also allows unpatching as this
|
14
|
+
discards our modification. This is why we patch on import and do not support
|
15
|
+
unpatching:
|
16
|
+
|
17
|
+
with contextlib.redirect_stderr(...):
|
18
|
+
from ... import console_capture
|
19
|
+
# Here, everything works fine.
|
20
|
+
# Here, callbacks are never called again.
|
21
|
+
|
22
|
+
In particular, it does not work with some combinations of pytest's
|
23
|
+
`capfd` / `capsys` fixtures and pytest's `--capture` option.
|
24
|
+
"""
|
25
|
+
|
26
|
+
from __future__ import annotations
|
27
|
+
|
28
|
+
import sys
|
29
|
+
import threading
|
30
|
+
from typing import IO, AnyStr, Callable, Protocol
|
31
|
+
|
32
|
+
|
33
|
+
class CannotCaptureConsoleError(Exception):
|
34
|
+
"""The module failed to patch stdout or stderr."""
|
35
|
+
|
36
|
+
|
37
|
+
class _WriteCallback(Protocol):
|
38
|
+
"""A callback that receives intercepted bytes or string data."""
|
39
|
+
|
40
|
+
def __call__(
|
41
|
+
self,
|
42
|
+
data: bytes | str,
|
43
|
+
written: int,
|
44
|
+
/,
|
45
|
+
) -> None:
|
46
|
+
"""Intercept data passed to `write()`.
|
47
|
+
|
48
|
+
Args:
|
49
|
+
data: The object passed to stderr's or stdout's `write()`.
|
50
|
+
written: The number of bytes or characters written.
|
51
|
+
This is the return value of `write()`.
|
52
|
+
"""
|
53
|
+
|
54
|
+
|
55
|
+
_module_lock = threading.Lock()
|
56
|
+
|
57
|
+
_patch_exception: CannotCaptureConsoleError | None = None
|
58
|
+
|
59
|
+
_next_callback_id: int = 1
|
60
|
+
|
61
|
+
_stdout_callbacks: dict[int, _WriteCallback] = {}
|
62
|
+
_stderr_callbacks: dict[int, _WriteCallback] = {}
|
63
|
+
|
64
|
+
|
65
|
+
def capture_stdout(callback: _WriteCallback) -> Callable[[], None]:
|
66
|
+
"""Install a callback that runs after every write to sys.stdout.
|
67
|
+
|
68
|
+
Args:
|
69
|
+
callback: A callback to invoke after running `sys.stdout.write`.
|
70
|
+
This may be called from any thread, so it must be thread-safe.
|
71
|
+
Exceptions are propagated to the caller of `write`.
|
72
|
+
See `_WriteCallback` for the exact protocol.
|
73
|
+
|
74
|
+
Returns:
|
75
|
+
A function to uninstall the callback.
|
76
|
+
|
77
|
+
Raises:
|
78
|
+
CannotCaptureConsoleError: If patching failed on import.
|
79
|
+
"""
|
80
|
+
with _module_lock:
|
81
|
+
if _patch_exception:
|
82
|
+
raise _patch_exception
|
83
|
+
|
84
|
+
return _insert_disposably(
|
85
|
+
_stdout_callbacks,
|
86
|
+
callback,
|
87
|
+
)
|
88
|
+
|
89
|
+
|
90
|
+
def capture_stderr(callback: _WriteCallback) -> Callable[[], None]:
|
91
|
+
"""Install a callback that runs after every write to sys.sdterr.
|
92
|
+
|
93
|
+
Args:
|
94
|
+
callback: A callback to invoke after running `sys.stderr.write`.
|
95
|
+
This may be called from any thread, so it must be thread-safe.
|
96
|
+
Exceptions are propagated to the caller of `write`.
|
97
|
+
See `_WriteCallback` for the exact protocol.
|
98
|
+
|
99
|
+
Returns:
|
100
|
+
A function to uninstall the callback.
|
101
|
+
|
102
|
+
Raises:
|
103
|
+
CannotCaptureConsoleError: If patching failed on import.
|
104
|
+
"""
|
105
|
+
with _module_lock:
|
106
|
+
if _patch_exception:
|
107
|
+
raise _patch_exception
|
108
|
+
|
109
|
+
return _insert_disposably(
|
110
|
+
_stderr_callbacks,
|
111
|
+
callback,
|
112
|
+
)
|
113
|
+
|
114
|
+
|
115
|
+
def _insert_disposably(
|
116
|
+
callback_dict: dict[int, _WriteCallback],
|
117
|
+
callback: _WriteCallback,
|
118
|
+
) -> Callable[[], None]:
|
119
|
+
global _next_callback_id
|
120
|
+
id = _next_callback_id
|
121
|
+
_next_callback_id += 1
|
122
|
+
|
123
|
+
disposed = False
|
124
|
+
|
125
|
+
def dispose() -> None:
|
126
|
+
nonlocal disposed
|
127
|
+
|
128
|
+
with _module_lock:
|
129
|
+
if disposed:
|
130
|
+
return
|
131
|
+
|
132
|
+
del callback_dict[id]
|
133
|
+
|
134
|
+
disposed = True
|
135
|
+
|
136
|
+
callback_dict[id] = callback
|
137
|
+
return dispose
|
138
|
+
|
139
|
+
|
140
|
+
def _patch(
|
141
|
+
stdout_or_stderr: IO[AnyStr],
|
142
|
+
callbacks: dict[int, _WriteCallback],
|
143
|
+
) -> None:
|
144
|
+
orig_write: Callable[[AnyStr], int]
|
145
|
+
|
146
|
+
def write_with_callbacks(s: AnyStr, /) -> int:
|
147
|
+
n = orig_write(s)
|
148
|
+
|
149
|
+
# We make a copy here because callbacks could, in theory, modify
|
150
|
+
# the list of callbacks.
|
151
|
+
with _module_lock:
|
152
|
+
callbacks_copy = list(callbacks.values())
|
153
|
+
|
154
|
+
for cb in callbacks_copy:
|
155
|
+
cb(s, n)
|
156
|
+
|
157
|
+
return n
|
158
|
+
|
159
|
+
orig_write = stdout_or_stderr.write
|
160
|
+
|
161
|
+
# mypy==1.14.1 fails to type-check this:
|
162
|
+
# Incompatible types in assignment (expression has type
|
163
|
+
# "Callable[[bytes], int]", variable has type overloaded function)
|
164
|
+
stdout_or_stderr.write = write_with_callbacks # type: ignore
|
165
|
+
|
166
|
+
|
167
|
+
try:
|
168
|
+
_patch(sys.stdout, _stdout_callbacks)
|
169
|
+
_patch(sys.stderr, _stderr_callbacks)
|
170
|
+
except Exception as _patch_exception_cause:
|
171
|
+
_patch_exception = CannotCaptureConsoleError()
|
172
|
+
_patch_exception.__cause__ = _patch_exception_cause
|