wandb 0.19.6rc4__py3-none-any.whl → 0.19.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +1 -1
- wandb/__init__.pyi +25 -5
- wandb/apis/public/_generated/__init__.py +21 -0
- wandb/apis/public/_generated/base.py +128 -0
- wandb/apis/public/_generated/enums.py +4 -0
- wandb/apis/public/_generated/input_types.py +4 -0
- wandb/apis/public/_generated/operations.py +15 -0
- wandb/apis/public/_generated/server_features_query.py +27 -0
- wandb/apis/public/_generated/typing_compat.py +14 -0
- wandb/apis/public/api.py +192 -6
- wandb/apis/public/artifacts.py +13 -45
- wandb/apis/public/registries.py +573 -0
- wandb/apis/public/utils.py +36 -0
- wandb/bin/gpu_stats +0 -0
- wandb/cli/cli.py +11 -20
- wandb/env.py +10 -0
- wandb/proto/v3/wandb_internal_pb2.py +243 -222
- wandb/proto/v3/wandb_server_pb2.py +4 -4
- wandb/proto/v3/wandb_settings_pb2.py +1 -1
- wandb/proto/v4/wandb_internal_pb2.py +226 -222
- wandb/proto/v4/wandb_server_pb2.py +4 -4
- wandb/proto/v4/wandb_settings_pb2.py +1 -1
- wandb/proto/v5/wandb_internal_pb2.py +226 -222
- wandb/proto/v5/wandb_server_pb2.py +4 -4
- wandb/proto/v5/wandb_settings_pb2.py +1 -1
- wandb/sdk/artifacts/_graphql_fragments.py +126 -0
- wandb/sdk/artifacts/artifact.py +43 -88
- wandb/sdk/backend/backend.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +14 -6
- wandb/sdk/data_types/helper_types/image_mask.py +12 -6
- wandb/sdk/data_types/saved_model.py +35 -46
- wandb/sdk/data_types/video.py +7 -16
- wandb/sdk/interface/interface.py +26 -10
- wandb/sdk/interface/interface_queue.py +5 -8
- wandb/sdk/interface/interface_relay.py +1 -6
- wandb/sdk/interface/interface_shared.py +21 -99
- wandb/sdk/interface/interface_sock.py +2 -13
- wandb/sdk/interface/router.py +21 -15
- wandb/sdk/interface/router_queue.py +2 -1
- wandb/sdk/interface/router_relay.py +2 -1
- wandb/sdk/interface/router_sock.py +5 -4
- wandb/sdk/internal/handler.py +4 -3
- wandb/sdk/internal/internal_api.py +12 -1
- wandb/sdk/internal/sender.py +0 -18
- wandb/sdk/lib/apikey.py +87 -26
- wandb/sdk/lib/asyncio_compat.py +210 -0
- wandb/sdk/lib/progress.py +78 -16
- wandb/sdk/lib/service_connection.py +1 -1
- wandb/sdk/lib/sock_client.py +7 -7
- wandb/sdk/mailbox/__init__.py +23 -0
- wandb/sdk/mailbox/handles.py +199 -0
- wandb/sdk/mailbox/mailbox.py +121 -0
- wandb/sdk/mailbox/wait_with_progress.py +134 -0
- wandb/sdk/service/server_sock.py +5 -1
- wandb/sdk/service/streams.py +66 -74
- wandb/sdk/verify/verify.py +54 -2
- wandb/sdk/wandb_init.py +61 -61
- wandb/sdk/wandb_login.py +7 -4
- wandb/sdk/wandb_metadata.py +65 -34
- wandb/sdk/wandb_require.py +14 -8
- wandb/sdk/wandb_run.py +82 -87
- wandb/sdk/wandb_settings.py +3 -3
- wandb/sdk/wandb_setup.py +19 -8
- wandb/sdk/wandb_sync.py +2 -4
- wandb/util.py +3 -1
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/METADATA +2 -2
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/RECORD +70 -57
- wandb/sdk/lib/mailbox.py +0 -442
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/WHEEL +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/entry_points.txt +0 -0
- {wandb-0.19.6rc4.dist-info → wandb-0.19.7.dist-info}/licenses/LICENSE +0 -0
@@ -7,10 +7,10 @@ See interface.py for how interface classes relate to each other.
|
|
7
7
|
import logging
|
8
8
|
from typing import TYPE_CHECKING, Any, Optional
|
9
9
|
|
10
|
-
from
|
10
|
+
from wandb.sdk.mailbox import Mailbox
|
11
|
+
|
11
12
|
from ..lib.sock_client import SockClient
|
12
13
|
from .interface_shared import InterfaceShared
|
13
|
-
from .message_future import MessageFuture
|
14
14
|
from .router_sock import MessageSockRouter
|
15
15
|
|
16
16
|
if TYPE_CHECKING:
|
@@ -32,7 +32,6 @@ class InterfaceSock(InterfaceShared):
|
|
32
32
|
# _sock_client is used when abstract method _init_router() is called by constructor
|
33
33
|
self._sock_client = sock_client
|
34
34
|
super().__init__(mailbox=mailbox)
|
35
|
-
self._process_check = False
|
36
35
|
self._stream_id = stream_id
|
37
36
|
|
38
37
|
def _init_router(self) -> None:
|
@@ -45,13 +44,3 @@ class InterfaceSock(InterfaceShared):
|
|
45
44
|
def _publish(self, record: "pb.Record", local: Optional[bool] = None) -> None:
|
46
45
|
self._assign(record)
|
47
46
|
self._sock_client.send_record_publish(record)
|
48
|
-
|
49
|
-
def _communicate_async(
|
50
|
-
self, rec: "pb.Record", local: Optional[bool] = None
|
51
|
-
) -> MessageFuture:
|
52
|
-
self._assign(rec)
|
53
|
-
assert self._router
|
54
|
-
if self._process_check and self._process and not self._process.is_alive():
|
55
|
-
raise Exception("The wandb backend process has shutdown")
|
56
|
-
future = self._router.send_and_receive(rec, local=local)
|
57
|
-
return future
|
wandb/sdk/interface/router.py
CHANGED
@@ -10,7 +10,8 @@ import uuid
|
|
10
10
|
from abc import abstractmethod
|
11
11
|
from typing import TYPE_CHECKING, Dict, Optional
|
12
12
|
|
13
|
-
from
|
13
|
+
from wandb.sdk import mailbox
|
14
|
+
|
14
15
|
from .message_future import MessageFuture
|
15
16
|
|
16
17
|
if TYPE_CHECKING:
|
@@ -63,20 +64,25 @@ class MessageRouter:
|
|
63
64
|
raise NotImplementedError
|
64
65
|
|
65
66
|
def message_loop(self) -> None:
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
67
|
+
try:
|
68
|
+
while not self._join_event.is_set():
|
69
|
+
try:
|
70
|
+
msg = self._read_message()
|
71
|
+
except EOFError:
|
72
|
+
# On abnormal shutdown the queue will be destroyed underneath
|
73
|
+
# resulting in EOFError. message_loop needs to exit..
|
74
|
+
logger.warning("EOFError seen in message_loop")
|
75
|
+
break
|
76
|
+
except MessageRouterClosedError as e:
|
77
|
+
logger.warning("message_loop has been closed", exc_info=e)
|
78
|
+
break
|
79
|
+
if not msg:
|
80
|
+
continue
|
81
|
+
self._handle_msg_rcv(msg)
|
82
|
+
|
83
|
+
finally:
|
84
|
+
if self._mailbox:
|
85
|
+
self._mailbox.close()
|
80
86
|
|
81
87
|
def send_and_receive(
|
82
88
|
self, rec: "pb.Record", local: Optional[bool] = None
|
@@ -6,8 +6,9 @@ Router to manage responses from a socket client.
|
|
6
6
|
|
7
7
|
from typing import TYPE_CHECKING, Optional
|
8
8
|
|
9
|
-
from
|
10
|
-
from
|
9
|
+
from wandb.sdk.lib.sock_client import SockClient, SockClientClosedError
|
10
|
+
from wandb.sdk.mailbox import Mailbox
|
11
|
+
|
11
12
|
from .router import MessageRouter, MessageRouterClosedError
|
12
13
|
|
13
14
|
if TYPE_CHECKING:
|
@@ -25,8 +26,8 @@ class MessageSockRouter(MessageRouter):
|
|
25
26
|
def _read_message(self) -> Optional["pb.Result"]:
|
26
27
|
try:
|
27
28
|
resp = self._sock_client.read_server_response(timeout=1)
|
28
|
-
except SockClientClosedError:
|
29
|
-
raise MessageRouterClosedError
|
29
|
+
except SockClientClosedError as e:
|
30
|
+
raise MessageRouterClosedError from e
|
30
31
|
if not resp:
|
31
32
|
return None
|
32
33
|
msg = resp.result_communicate
|
wandb/sdk/internal/handler.py
CHANGED
@@ -205,9 +205,6 @@ class HandleManager:
|
|
205
205
|
# defer is used to drive the sender finish state machine
|
206
206
|
self._dispatch_record(record, always_send=True)
|
207
207
|
|
208
|
-
def handle_request_login(self, record: Record) -> None:
|
209
|
-
self._dispatch_record(record)
|
210
|
-
|
211
208
|
def handle_request_python_packages(self, record: Record) -> None:
|
212
209
|
self._dispatch_record(record)
|
213
210
|
|
@@ -892,6 +889,10 @@ class HandleManager:
|
|
892
889
|
self._respond_result(result)
|
893
890
|
self._stopped.set()
|
894
891
|
|
892
|
+
def handle_request_operations(self, record: Record) -> None:
|
893
|
+
"""No-op. Not implemented for the legacy-service."""
|
894
|
+
self._respond_result(proto_util._result_from_record(record))
|
895
|
+
|
895
896
|
def finish(self) -> None:
|
896
897
|
logger.info("shutting down handler")
|
897
898
|
if self._system_monitor is not None:
|
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
|
|
115
115
|
root_dir: Optional[str]
|
116
116
|
api_key: Optional[str]
|
117
117
|
entity: Optional[str]
|
118
|
+
organization: Optional[str]
|
118
119
|
project: Optional[str]
|
119
120
|
_extra_http_headers: Optional[Mapping[str, str]]
|
120
121
|
_proxies: Optional[Mapping[str, str]]
|
@@ -256,6 +257,7 @@ class Api:
|
|
256
257
|
"root_dir": None,
|
257
258
|
"api_key": None,
|
258
259
|
"entity": None,
|
260
|
+
"organization": None,
|
259
261
|
"project": None,
|
260
262
|
"_extra_http_headers": None,
|
261
263
|
"_proxies": None,
|
@@ -489,7 +491,8 @@ class Api:
|
|
489
491
|
{
|
490
492
|
"entity": "models",
|
491
493
|
"base_url": "https://api.wandb.ai",
|
492
|
-
"project": None
|
494
|
+
"project": None,
|
495
|
+
"organization": "my-org",
|
493
496
|
}
|
494
497
|
"""
|
495
498
|
result = self.default_settings.copy()
|
@@ -504,6 +507,14 @@ class Api:
|
|
504
507
|
),
|
505
508
|
env=self._environ,
|
506
509
|
),
|
510
|
+
"organization": env.get_organization(
|
511
|
+
self._settings.get(
|
512
|
+
Settings.DEFAULT_SECTION,
|
513
|
+
"organization",
|
514
|
+
fallback=result.get("organization"),
|
515
|
+
),
|
516
|
+
env=self._environ,
|
517
|
+
),
|
507
518
|
"project": env.get_project(
|
508
519
|
self._settings.get(
|
509
520
|
Settings.DEFAULT_SECTION,
|
wandb/sdk/internal/sender.py
CHANGED
@@ -544,23 +544,6 @@ class SendManager:
|
|
544
544
|
logger.warning(f"Error emptying retry queue: {e}")
|
545
545
|
self._respond_result(result)
|
546
546
|
|
547
|
-
def send_request_login(self, record: "Record") -> None:
|
548
|
-
# TODO: do something with api_key or anonymous?
|
549
|
-
# TODO: return an error if we aren't logged in?
|
550
|
-
self._api.reauth()
|
551
|
-
viewer = self.get_viewer_info()
|
552
|
-
server_info = self.get_server_info()
|
553
|
-
# self._login_flags = json.loads(viewer.get("flags", "{}"))
|
554
|
-
# self._login_entity = viewer.get("entity")
|
555
|
-
if server_info:
|
556
|
-
logger.info(f"Login server info: {server_info}")
|
557
|
-
self._entity = viewer.get("entity")
|
558
|
-
if record.control.req_resp:
|
559
|
-
result = proto_util._result_from_record(record)
|
560
|
-
if self._entity:
|
561
|
-
result.response.login_response.active_entity = self._entity
|
562
|
-
self._respond_result(result)
|
563
|
-
|
564
547
|
def send_exit(self, record: "Record") -> None:
|
565
548
|
# track where the exit came from
|
566
549
|
self._record_exit = record
|
@@ -1491,7 +1474,6 @@ class SendManager:
|
|
1491
1474
|
self._job_builder.set_partial_source_id(use.id)
|
1492
1475
|
|
1493
1476
|
def send_request_log_artifact(self, record: "Record") -> None:
|
1494
|
-
assert record.control.req_resp
|
1495
1477
|
result = proto_util._result_from_record(record)
|
1496
1478
|
artifact = record.request.log_artifact.artifact
|
1497
1479
|
history_step = record.request.log_artifact.history_step
|
wandb/sdk/lib/apikey.py
CHANGED
@@ -1,5 +1,8 @@
|
|
1
1
|
"""apikey util."""
|
2
2
|
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import dataclasses
|
3
6
|
import os
|
4
7
|
import platform
|
5
8
|
import stat
|
@@ -32,8 +35,21 @@ LOGIN_CHOICES = [
|
|
32
35
|
LOGIN_CHOICE_DRYRUN,
|
33
36
|
]
|
34
37
|
|
38
|
+
|
39
|
+
@dataclasses.dataclass(frozen=True)
|
40
|
+
class _NetrcPermissions:
|
41
|
+
exists: bool
|
42
|
+
read_access: bool
|
43
|
+
write_access: bool
|
44
|
+
|
45
|
+
|
46
|
+
class WriteNetrcError(Exception):
|
47
|
+
"""Raised when we cannot write to the netrc file."""
|
48
|
+
|
49
|
+
|
35
50
|
Mode = Literal["allow", "must", "never", "false", "true"]
|
36
51
|
|
52
|
+
|
37
53
|
if TYPE_CHECKING:
|
38
54
|
from wandb.sdk.wandb_settings import Settings
|
39
55
|
|
@@ -170,29 +186,64 @@ def prompt_api_key( # noqa: C901
|
|
170
186
|
return key
|
171
187
|
|
172
188
|
|
173
|
-
def
|
189
|
+
def check_netrc_access(
|
190
|
+
netrc_path: str,
|
191
|
+
) -> _NetrcPermissions:
|
192
|
+
"""Check if we can read and write to the netrc file."""
|
193
|
+
file_exists = False
|
194
|
+
write_access = False
|
195
|
+
read_access = False
|
196
|
+
try:
|
197
|
+
st = os.stat(netrc_path)
|
198
|
+
file_exists = True
|
199
|
+
write_access = bool(st.st_mode & stat.S_IWUSR)
|
200
|
+
read_access = bool(st.st_mode & stat.S_IRUSR)
|
201
|
+
except FileNotFoundError:
|
202
|
+
# If the netrc file doesn't exist, we will create it.
|
203
|
+
write_access = True
|
204
|
+
read_access = True
|
205
|
+
except OSError as e:
|
206
|
+
wandb.termerror(f"Unable to read permissions for {netrc_path}, {e}")
|
207
|
+
|
208
|
+
return _NetrcPermissions(
|
209
|
+
exists=file_exists,
|
210
|
+
write_access=write_access,
|
211
|
+
read_access=read_access,
|
212
|
+
)
|
213
|
+
|
214
|
+
|
215
|
+
def write_netrc(host: str, entity: str, key: str):
|
174
216
|
"""Add our host and key to .netrc."""
|
175
217
|
_, key_suffix = key.split("-", 1) if "-" in key else ("", key)
|
176
218
|
if len(key_suffix) != 40:
|
177
|
-
|
219
|
+
raise ValueError(
|
178
220
|
"API-key must be exactly 40 characters long: {} ({} chars)".format(
|
179
221
|
key_suffix, len(key_suffix)
|
180
222
|
)
|
181
223
|
)
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
224
|
+
|
225
|
+
normalized_host = urlparse(host).netloc
|
226
|
+
netrc_path = get_netrc_file_path()
|
227
|
+
netrc_access = check_netrc_access(netrc_path)
|
228
|
+
|
229
|
+
if not netrc_access.write_access or not netrc_access.read_access:
|
230
|
+
raise WriteNetrcError(
|
231
|
+
f"Cannot access {netrc_path}. In order to persist your API key, "
|
232
|
+
"grant read and write permissions for your user to the file "
|
233
|
+
'or specify a different file with the environment variable "NETRC=<new_netrc_path>".'
|
188
234
|
)
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
235
|
+
|
236
|
+
machine_line = f"machine {normalized_host}"
|
237
|
+
orig_lines = None
|
238
|
+
try:
|
239
|
+
with open(netrc_path) as f:
|
240
|
+
orig_lines = f.read().strip().split("\n")
|
241
|
+
except FileNotFoundError:
|
242
|
+
wandb.termlog("No netrc file found, creating one.")
|
243
|
+
except OSError as e:
|
244
|
+
raise WriteNetrcError(f"Unable to read {netrc_path}") from e
|
245
|
+
|
246
|
+
try:
|
196
247
|
with open(netrc_path, "w") as f:
|
197
248
|
if orig_lines:
|
198
249
|
# delete this machine from the file if it's already there.
|
@@ -206,20 +257,22 @@ def write_netrc(host: str, entity: str, key: str) -> Optional[bool]:
|
|
206
257
|
skip -= 1
|
207
258
|
else:
|
208
259
|
f.write("{}\n".format(line))
|
260
|
+
|
261
|
+
wandb.termlog(
|
262
|
+
f"Appending key for {normalized_host} to your netrc file: {netrc_path}"
|
263
|
+
)
|
209
264
|
f.write(
|
210
265
|
textwrap.dedent(
|
211
266
|
"""\
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
267
|
+
machine {host}
|
268
|
+
login {entity}
|
269
|
+
password {key}
|
270
|
+
"""
|
216
271
|
).format(host=normalized_host, entity=entity, key=key)
|
217
272
|
)
|
218
273
|
os.chmod(netrc_path, stat.S_IRUSR | stat.S_IWUSR)
|
219
|
-
|
220
|
-
|
221
|
-
wandb.termerror(f"Unable to read {netrc_path}")
|
222
|
-
return None
|
274
|
+
except OSError as e:
|
275
|
+
raise WriteNetrcError(f"Unable to write {netrc_path}") from e
|
223
276
|
|
224
277
|
|
225
278
|
def write_key(
|
@@ -250,7 +303,15 @@ def api_key(settings: Optional["Settings"] = None) -> Optional[str]:
|
|
250
303
|
settings = wandb.setup().settings
|
251
304
|
if settings.api_key:
|
252
305
|
return settings.api_key
|
253
|
-
|
254
|
-
|
255
|
-
|
306
|
+
|
307
|
+
netrc_access = check_netrc_access(get_netrc_file_path())
|
308
|
+
if netrc_access.exists and not netrc_access.read_access:
|
309
|
+
wandb.termwarn(f"Cannot access {get_netrc_file_path()}.")
|
310
|
+
return None
|
311
|
+
|
312
|
+
if netrc_access.exists:
|
313
|
+
auth = get_netrc_auth(settings.base_url)
|
314
|
+
if auth:
|
315
|
+
return auth[-1]
|
316
|
+
|
256
317
|
return None
|
@@ -0,0 +1,210 @@
|
|
1
|
+
"""Functions for compatibility with asyncio."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import concurrent
|
7
|
+
import concurrent.futures
|
8
|
+
import contextlib
|
9
|
+
import threading
|
10
|
+
from typing import Any, AsyncIterator, Callable, Coroutine, Iterator, TypeVar
|
11
|
+
|
12
|
+
_T = TypeVar("_T")
|
13
|
+
|
14
|
+
|
15
|
+
def run(fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
|
16
|
+
"""Run `fn` in an asyncio loop in a new thread.
|
17
|
+
|
18
|
+
This must always be used instead of `asyncio.run` which fails if there is
|
19
|
+
an active `asyncio` event loop in the current thread. Since `wandb` was not
|
20
|
+
originally designed with `asyncio` in mind, using `asyncio.run` would break
|
21
|
+
users who were calling `wandb` methods from an `asyncio` loop.
|
22
|
+
|
23
|
+
Note that due to starting a new thread, this is slightly slow.
|
24
|
+
"""
|
25
|
+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
26
|
+
runner = _Runner()
|
27
|
+
future = executor.submit(runner.run, fn)
|
28
|
+
|
29
|
+
try:
|
30
|
+
return future.result()
|
31
|
+
|
32
|
+
finally:
|
33
|
+
runner.cancel()
|
34
|
+
|
35
|
+
|
36
|
+
class _RunnerCancelledError(Exception):
|
37
|
+
"""The `_Runner.run()` invocation was cancelled."""
|
38
|
+
|
39
|
+
|
40
|
+
class _Runner:
|
41
|
+
"""Runs an asyncio event loop allowing cancellation.
|
42
|
+
|
43
|
+
This is like `asyncio.run()`, except it provides a `cancel()` method
|
44
|
+
meant to be called in a `finally` block.
|
45
|
+
|
46
|
+
Without this, it is impossible to make `asyncio.run()` stop if it runs
|
47
|
+
in a non-main thread. In particular, a KeyboardInterrupt causes the
|
48
|
+
ThreadPoolExecutor above to block until the asyncio thread completes,
|
49
|
+
but there is no way to tell the asyncio thread to cancel its work.
|
50
|
+
A second KeyboardInterrupt makes ThreadPoolExecutor give up while the
|
51
|
+
asyncio thread still runs in the background, with terrible effects if it
|
52
|
+
prints to the user's terminal.
|
53
|
+
"""
|
54
|
+
|
55
|
+
def __init__(self) -> None:
|
56
|
+
self._lock = threading.Condition()
|
57
|
+
|
58
|
+
self._is_cancelled = False
|
59
|
+
self._started = False
|
60
|
+
self._done = False
|
61
|
+
|
62
|
+
self._loop: asyncio.AbstractEventLoop | None = None
|
63
|
+
self._cancel_event: asyncio.Event | None = None
|
64
|
+
|
65
|
+
def run(self, fn: Callable[[], Coroutine[Any, Any, _T]]) -> _T:
|
66
|
+
"""Run a coroutine in asyncio, cancelling it on `cancel()`.
|
67
|
+
|
68
|
+
Returns:
|
69
|
+
The result of the coroutine returned by `fn`.
|
70
|
+
|
71
|
+
Raises:
|
72
|
+
_RunnerCancelledError: If `cancel()` is called.
|
73
|
+
"""
|
74
|
+
return asyncio.run(self._run_or_cancel(fn))
|
75
|
+
|
76
|
+
async def _run_or_cancel(
|
77
|
+
self,
|
78
|
+
fn: Callable[[], Coroutine[Any, Any, _T]],
|
79
|
+
) -> _T:
|
80
|
+
with self._lock:
|
81
|
+
if self._is_cancelled:
|
82
|
+
raise _RunnerCancelledError()
|
83
|
+
|
84
|
+
self._loop = asyncio.get_running_loop()
|
85
|
+
self._cancel_event = asyncio.Event()
|
86
|
+
self._started = True
|
87
|
+
|
88
|
+
cancellation_task = asyncio.create_task(self._cancel_event.wait())
|
89
|
+
fn_task = asyncio.create_task(fn())
|
90
|
+
|
91
|
+
try:
|
92
|
+
await asyncio.wait(
|
93
|
+
[cancellation_task, fn_task],
|
94
|
+
return_when=asyncio.FIRST_COMPLETED,
|
95
|
+
)
|
96
|
+
|
97
|
+
if fn_task.done():
|
98
|
+
return fn_task.result()
|
99
|
+
else:
|
100
|
+
raise _RunnerCancelledError()
|
101
|
+
|
102
|
+
finally:
|
103
|
+
cancellation_task.cancel()
|
104
|
+
fn_task.cancel()
|
105
|
+
|
106
|
+
with self._lock:
|
107
|
+
self._done = True
|
108
|
+
|
109
|
+
def cancel(self) -> None:
|
110
|
+
"""Cancel all asyncio work started by `run()`."""
|
111
|
+
with self._lock:
|
112
|
+
if self._is_cancelled:
|
113
|
+
return
|
114
|
+
self._is_cancelled = True
|
115
|
+
|
116
|
+
if self._done or not self._started:
|
117
|
+
# If the runner already finished, no need to cancel it.
|
118
|
+
#
|
119
|
+
# If the runner hasn't started the loop yet, then it will not
|
120
|
+
# as we already set _is_cancelled.
|
121
|
+
return
|
122
|
+
|
123
|
+
assert self._loop
|
124
|
+
assert self._cancel_event
|
125
|
+
self._loop.call_soon_threadsafe(self._cancel_event.set)
|
126
|
+
|
127
|
+
|
128
|
+
class TaskGroup:
|
129
|
+
"""Object that `open_task_group()` yields."""
|
130
|
+
|
131
|
+
def __init__(self) -> None:
|
132
|
+
self._tasks: list[asyncio.Task] = []
|
133
|
+
|
134
|
+
def start_soon(self, coro: Coroutine[Any, Any, Any]) -> None:
|
135
|
+
"""Schedule a task in the group.
|
136
|
+
|
137
|
+
Args:
|
138
|
+
coro: The return value of the `async` function defining the task.
|
139
|
+
"""
|
140
|
+
self._tasks.append(asyncio.create_task(coro))
|
141
|
+
|
142
|
+
async def _wait_all(self) -> None:
|
143
|
+
"""Block until all tasks complete.
|
144
|
+
|
145
|
+
Raises:
|
146
|
+
Exception: If one or more tasks raises an exception, one of these
|
147
|
+
is raised arbitrarily.
|
148
|
+
"""
|
149
|
+
done, _ = await asyncio.wait(
|
150
|
+
self._tasks,
|
151
|
+
# NOTE: Cancelling a task counts as a normal exit,
|
152
|
+
# not an exception.
|
153
|
+
return_when=concurrent.futures.FIRST_EXCEPTION,
|
154
|
+
)
|
155
|
+
|
156
|
+
for task in done:
|
157
|
+
try:
|
158
|
+
if exc := task.exception():
|
159
|
+
raise exc
|
160
|
+
except asyncio.CancelledError:
|
161
|
+
pass
|
162
|
+
|
163
|
+
def _cancel_all(self) -> None:
|
164
|
+
"""Cancel all tasks."""
|
165
|
+
for task in self._tasks:
|
166
|
+
# NOTE: It is safe to cancel tasks that have already completed.
|
167
|
+
task.cancel()
|
168
|
+
|
169
|
+
|
170
|
+
@contextlib.asynccontextmanager
|
171
|
+
async def open_task_group() -> AsyncIterator[TaskGroup]:
|
172
|
+
"""Create a task group.
|
173
|
+
|
174
|
+
`asyncio` gained task groups in Python 3.11.
|
175
|
+
|
176
|
+
This is an async context manager, meant to be used with `async with`.
|
177
|
+
On exit, it blocks until all subtasks complete. If any subtask fails, or if
|
178
|
+
the current task is cancelled, it cancels all subtasks in the group and
|
179
|
+
raises the subtask's exception. If multiple subtasks fail simultaneously,
|
180
|
+
one of their exceptions is chosen arbitrarily.
|
181
|
+
|
182
|
+
NOTE: Subtask exceptions do not propagate until the context manager exits.
|
183
|
+
This means that the task group cannot cancel code running inside the
|
184
|
+
`async with` block .
|
185
|
+
"""
|
186
|
+
task_group = TaskGroup()
|
187
|
+
|
188
|
+
try:
|
189
|
+
yield task_group
|
190
|
+
await task_group._wait_all()
|
191
|
+
finally:
|
192
|
+
task_group._cancel_all()
|
193
|
+
|
194
|
+
|
195
|
+
@contextlib.contextmanager
|
196
|
+
def cancel_on_exit(coro: Coroutine[Any, Any, Any]) -> Iterator[None]:
|
197
|
+
"""Schedule a task, cancelling it when exiting the context manager.
|
198
|
+
|
199
|
+
If the given coroutine raises an exception, that exception is raised
|
200
|
+
when exiting the context manager.
|
201
|
+
"""
|
202
|
+
task = asyncio.create_task(coro)
|
203
|
+
|
204
|
+
try:
|
205
|
+
yield
|
206
|
+
finally:
|
207
|
+
if task.done() and (exception := task.exception()):
|
208
|
+
raise exception
|
209
|
+
|
210
|
+
task.cancel()
|